diff --git a/.cargo/config.toml b/.cargo/config.toml index 6f84a73999..a3b2334182 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -19,3 +19,8 @@ rustflags = [ # pthread_key_create() destructors and segfault after a DSO unloading [target.'cfg(all(target_env = "gnu", not(target_os = "windows")))'] rustflags = ["-C", "link-args=-Wl,-z,nodelete"] + +# Temporary local llm_adapter override. +# Uncomment when verifying AFFiNE against the sibling llm_adapter workspace. +# [patch.crates-io] +# llm_adapter = { path = "../llm_adapter" } diff --git a/.docker/selfhost/schema.json b/.docker/selfhost/schema.json index ef819a6089..c625f9e778 100644 --- a/.docker/selfhost/schema.json +++ b/.docker/selfhost/schema.json @@ -197,8 +197,8 @@ "properties": { "SMTP.name": { "type": "string", - "description": "Name of the email server (e.g. your domain name)\n@default \"AFFiNE Server\"\n@environment `MAILER_SERVERNAME`", - "default": "AFFiNE Server" + "description": "Hostname used for SMTP HELO/EHLO (e.g. mail.example.com). Leave empty to use the system hostname.\n@default \"\"\n@environment `MAILER_SERVERNAME`", + "default": "" }, "SMTP.host": { "type": "string", @@ -237,8 +237,8 @@ }, "fallbackSMTP.name": { "type": "string", - "description": "Name of the fallback email server (e.g. your domain name)\n@default \"AFFiNE Server\"", - "default": "AFFiNE Server" + "description": "Hostname used for fallback SMTP HELO/EHLO (e.g. mail.example.com). 
Leave empty to use the system hostname.\n@default \"\"", + "default": "" }, "fallbackSMTP.host": { "type": "string", @@ -971,7 +971,7 @@ }, "scenarios": { "type": "object", - "description": "Use custom models in scenarios and override default settings.\n@default {\"override_enabled\":false,\"scenarios\":{\"audio_transcribing\":\"gemini-2.5-flash\",\"chat\":\"gemini-2.5-flash\",\"embedding\":\"gemini-embedding-001\",\"image\":\"gpt-image-1\",\"rerank\":\"gpt-4.1\",\"coding\":\"claude-sonnet-4-5@20250929\",\"complex_text_generation\":\"gpt-4o-2024-08-06\",\"quick_decision_making\":\"gpt-5-mini\",\"quick_text_generation\":\"gemini-2.5-flash\",\"polish_and_summarize\":\"gemini-2.5-flash\"}}", + "description": "Use custom models in scenarios and override default settings.\n@default {\"override_enabled\":false,\"scenarios\":{\"audio_transcribing\":\"gemini-2.5-flash\",\"chat\":\"gemini-2.5-flash\",\"embedding\":\"gemini-embedding-001\",\"image\":\"gpt-image-1\",\"coding\":\"claude-sonnet-4-5@20250929\",\"complex_text_generation\":\"gpt-5-mini\",\"quick_decision_making\":\"gpt-5-mini\",\"quick_text_generation\":\"gemini-2.5-flash\",\"polish_and_summarize\":\"gemini-2.5-flash\"}}", "default": { "override_enabled": false, "scenarios": { @@ -979,15 +979,24 @@ "chat": "gemini-2.5-flash", "embedding": "gemini-embedding-001", "image": "gpt-image-1", - "rerank": "gpt-4.1", "coding": "claude-sonnet-4-5@20250929", - "complex_text_generation": "gpt-4o-2024-08-06", + "complex_text_generation": "gpt-5-mini", "quick_decision_making": "gpt-5-mini", "quick_text_generation": "gemini-2.5-flash", "polish_and_summarize": "gemini-2.5-flash" } } }, + "providers.profiles": { + "type": "array", + "description": "The profile list for copilot providers.\n@default []", + "default": [] + }, + "providers.defaults": { + "type": "object", + "description": "The default provider ids for model output types and global fallback.\n@default {}", + "default": {} + }, "providers.openai": { "type": "object", 
"description": "The config for the openai provider.\n@default {\"apiKey\":\"\",\"baseURL\":\"https://api.openai.com/v1\"}\n@link https://github.com/openai/openai-node", diff --git a/.github/actions/build-rust/action.yml b/.github/actions/build-rust/action.yml index b3c81f97f5..0a9bb6dba0 100644 --- a/.github/actions/build-rust/action.yml +++ b/.github/actions/build-rust/action.yml @@ -50,8 +50,14 @@ runs: # https://github.com/tree-sitter/tree-sitter/issues/4186 # pass -D_BSD_SOURCE to clang to fix the tree-sitter build issue run: | - echo "CC=clang -D_BSD_SOURCE" >> "$GITHUB_ENV" - echo "TARGET_CC=clang -D_BSD_SOURCE" >> "$GITHUB_ENV" + if [[ "${{ inputs.target }}" == "aarch64-unknown-linux-gnu" ]]; then + # napi cross-toolchain 1.0.3 headers miss AT_HWCAP2 in elf.h + echo "CC=clang -D_BSD_SOURCE -DAT_HWCAP2=26" >> "$GITHUB_ENV" + echo "TARGET_CC=clang -D_BSD_SOURCE -DAT_HWCAP2=26" >> "$GITHUB_ENV" + else + echo "CC=clang -D_BSD_SOURCE" >> "$GITHUB_ENV" + echo "TARGET_CC=clang -D_BSD_SOURCE" >> "$GITHUB_ENV" + fi - name: Cache cargo uses: Swatinem/rust-cache@v2 diff --git a/.github/actions/setup-node/action.yml b/.github/actions/setup-node/action.yml index 78ce827954..6a527a9d53 100644 --- a/.github/actions/setup-node/action.yml +++ b/.github/actions/setup-node/action.yml @@ -53,7 +53,7 @@ runs: fi - name: Setup Node.js - uses: actions/setup-node@v4 + uses: actions/setup-node@v6 with: node-version-file: '.nvmrc' registry-url: https://npm.pkg.github.com @@ -93,7 +93,7 @@ runs: run: node -e "const p = $(yarn config cacheFolder --json).effective; console.log('yarn_global_cache=' + p)" >> $GITHUB_OUTPUT - name: Cache non-full yarn cache on Linux - uses: actions/cache@v4 + uses: actions/cache@v5 if: ${{ inputs.full-cache != 'true' && runner.os == 'Linux' }} with: path: | @@ -105,7 +105,7 @@ runs: # and the decompression performance on Windows is very terrible # so we reduce the number of cached files on non-Linux systems by remove node_modules from cache path. 
- name: Cache non-full yarn cache on non-Linux - uses: actions/cache@v4 + uses: actions/cache@v5 if: ${{ inputs.full-cache != 'true' && runner.os != 'Linux' }} with: path: | @@ -113,7 +113,7 @@ runs: key: node_modules-cache-${{ github.job }}-${{ runner.os }}-${{ runner.arch }}-${{ steps.system-info.outputs.name }}-${{ steps.system-info.outputs.release }}-${{ steps.system-info.outputs.version }} - name: Cache full yarn cache on Linux - uses: actions/cache@v4 + uses: actions/cache@v5 if: ${{ inputs.full-cache == 'true' && runner.os == 'Linux' }} with: path: | @@ -122,7 +122,7 @@ runs: key: node_modules-cache-full-${{ runner.os }}-${{ runner.arch }}-${{ steps.system-info.outputs.name }}-${{ steps.system-info.outputs.release }}-${{ steps.system-info.outputs.version }} - name: Cache full yarn cache on non-Linux - uses: actions/cache@v4 + uses: actions/cache@v5 if: ${{ inputs.full-cache == 'true' && runner.os != 'Linux' }} with: path: | @@ -154,7 +154,7 @@ runs: # Note: Playwright's cache directory is hard coded because that's what it # says to do in the docs. There doesn't appear to be a command that prints # it out for us. - - uses: actions/cache@v4 + - uses: actions/cache@v5 id: playwright-cache if: ${{ inputs.playwright-install == 'true' }} with: @@ -189,7 +189,7 @@ runs: run: | echo "version=$(yarn why --json electron | grep -h 'workspace:.' 
| jq --raw-output '.children[].locator' | sed -e 's/@playwright\/test@.*://' | head -n 1)" >> $GITHUB_OUTPUT - - uses: actions/cache@v4 + - uses: actions/cache@v5 id: electron-cache if: ${{ inputs.electron-install == 'true' }} with: diff --git a/.github/helm/affine/charts/front/values.yaml b/.github/helm/affine/charts/front/values.yaml index 08933d27d1..cc4ac4b6bb 100644 --- a/.github/helm/affine/charts/front/values.yaml +++ b/.github/helm/affine/charts/front/values.yaml @@ -31,10 +31,10 @@ podSecurityContext: resources: limits: cpu: '1' - memory: 4Gi + memory: 6Gi requests: cpu: '1' - memory: 2Gi + memory: 4Gi probe: initialDelaySeconds: 20 diff --git a/.github/workflows/auto-labeler.yml b/.github/workflows/auto-labeler.yml index 010a7eadf3..f45b177645 100644 --- a/.github/workflows/auto-labeler.yml +++ b/.github/workflows/auto-labeler.yml @@ -13,5 +13,5 @@ jobs: pull-requests: write runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: actions/labeler@v5 + - uses: actions/checkout@v6 + - uses: actions/labeler@v6 diff --git a/.github/workflows/build-images.yml b/.github/workflows/build-images.yml index dd5d84fea3..01a2bc2cb9 100644 --- a/.github/workflows/build-images.yml +++ b/.github/workflows/build-images.yml @@ -24,7 +24,7 @@ jobs: runs-on: ubuntu-latest environment: ${{ inputs.build-type }} steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Version uses: ./.github/actions/setup-version with: @@ -57,7 +57,7 @@ jobs: runs-on: ubuntu-latest environment: ${{ inputs.build-type }} steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Version uses: ./.github/actions/setup-version with: @@ -89,7 +89,7 @@ jobs: runs-on: ubuntu-latest environment: ${{ inputs.build-type }} steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Version uses: ./.github/actions/setup-version with: @@ -118,7 +118,7 @@ jobs: build-server-native: name: Build Server native - ${{ 
matrix.targets.name }} - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 environment: ${{ inputs.build-type }} strategy: fail-fast: false @@ -132,7 +132,7 @@ jobs: file: server-native.armv7.node steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Version uses: ./.github/actions/setup-version with: @@ -166,7 +166,7 @@ jobs: needs: - build-server-native steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Version uses: ./.github/actions/setup-version with: @@ -202,7 +202,7 @@ jobs: - build-mobile - build-admin steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Download server dist uses: actions/download-artifact@v4 with: @@ -222,7 +222,7 @@ jobs: # setup node without cache configuration # Prisma cache is not compatible with docker build cache - name: Setup Node.js - uses: actions/setup-node@v4 + uses: actions/setup-node@v6 with: node-version-file: '.nvmrc' registry-url: https://npm.pkg.github.com diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 0c072cbf3d..dde4f784bb 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -46,7 +46,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Initialize CodeQL uses: github/codeql-action/init@v3 @@ -67,9 +67,9 @@ jobs: name: Lint runs-on: ubuntu-24.04-arm steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Go (for actionlint) - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version: 'stable' - name: Install actionlint @@ -111,7 +111,7 @@ jobs: env: NODE_OPTIONS: --max-old-space-size=14384 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Node.js uses: ./.github/actions/setup-node with: @@ -138,7 +138,7 @@ jobs: outputs: run-rust: ${{ steps.rust-filter.outputs.rust }} steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - uses: 
dorny/paths-filter@v3 id: rust-filter @@ -159,7 +159,7 @@ jobs: needs: - rust-test-filter steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - uses: ./.github/actions/build-rust with: target: x86_64-unknown-linux-gnu @@ -182,7 +182,7 @@ jobs: needs: - build-server-native steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Node.js uses: ./.github/actions/setup-node with: @@ -212,7 +212,7 @@ jobs: name: Check yarn binary runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Run check run: | set -euo pipefail @@ -226,9 +226,9 @@ jobs: strategy: fail-fast: false matrix: - shard: [1, 2] + shard: [1, 2, 3, 4, 5] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Node.js uses: ./.github/actions/setup-node with: @@ -256,7 +256,7 @@ jobs: name: E2E BlockSuite Cross Browser Test runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Node.js uses: ./.github/actions/setup-node with: @@ -282,52 +282,6 @@ jobs: path: ./test-results if-no-files-found: ignore - bundler-matrix: - name: Bundler Matrix (${{ matrix.bundler }}) - runs-on: ubuntu-24.04-arm - strategy: - fail-fast: false - matrix: - bundler: [webpack, rspack] - steps: - - uses: actions/checkout@v4 - - name: Setup Node.js - uses: ./.github/actions/setup-node - with: - playwright-install: false - electron-install: false - full-cache: true - - - name: Run frontend build matrix - env: - AFFINE_BUNDLER: ${{ matrix.bundler }} - run: | - set -euo pipefail - packages=( - "@affine/web" - "@affine/mobile" - "@affine/ios" - "@affine/android" - "@affine/admin" - "@affine/electron-renderer" - ) - summary="test-results-bundler-${AFFINE_BUNDLER}.txt" - : > "$summary" - for pkg in "${packages[@]}"; do - start=$(date +%s) - yarn affine "$pkg" build - end=$(date +%s) - echo "${pkg},$((end-start))" >> "$summary" - done - - - name: Upload bundler timing - if: always() - uses: 
actions/upload-artifact@v4 - with: - name: test-results-bundler-${{ matrix.bundler }} - path: ./test-results-bundler-${{ matrix.bundler }}.txt - if-no-files-found: ignore - e2e-test: name: E2E Test runs-on: ubuntu-24.04-arm @@ -340,7 +294,7 @@ jobs: matrix: shard: [1, 2, 3, 4, 5] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Node.js uses: ./.github/actions/setup-node with: @@ -372,7 +326,7 @@ jobs: matrix: shard: [1, 2] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Node.js uses: ./.github/actions/setup-node with: @@ -402,9 +356,9 @@ jobs: strategy: fail-fast: false matrix: - shard: [1, 2, 3] + shard: [1, 2, 3, 4, 5] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Node.js uses: ./.github/actions/setup-node with: @@ -437,7 +391,7 @@ jobs: env: CARGO_PROFILE_RELEASE_DEBUG: '1' steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Node.js uses: ./.github/actions/setup-node with: @@ -476,7 +430,7 @@ jobs: - { os: macos-latest, target: aarch64-apple-darwin } steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Node.js uses: ./.github/actions/setup-node with: @@ -517,7 +471,7 @@ jobs: - { os: windows-latest, target: aarch64-pc-windows-msvc } steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - uses: samypr100/setup-dev-drive@v3 with: workspace-copy: true @@ -557,7 +511,7 @@ jobs: env: CARGO_PROFILE_RELEASE_DEBUG: '1' steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Node.js uses: ./.github/actions/setup-node with: @@ -580,7 +534,7 @@ jobs: name: Build @affine/electron renderer runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Node.js uses: ./.github/actions/setup-node with: @@ -607,7 +561,7 @@ jobs: needs: - build-native-linux steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Node.js 
uses: ./.github/actions/setup-node with: @@ -661,7 +615,7 @@ jobs: ports: - 9308:9308 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Node.js uses: ./.github/actions/setup-node @@ -742,7 +696,7 @@ jobs: stack-version: 9.0.1 security-enabled: false - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Node.js uses: ./.github/actions/setup-node @@ -805,7 +759,7 @@ jobs: ports: - 9308:9308 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Node.js uses: ./.github/actions/setup-node @@ -846,7 +800,7 @@ jobs: CARGO_TERM_COLOR: always MIRIFLAGS: -Zmiri-backtrace=full -Zmiri-tree-borrows steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Rust uses: dtolnay/rust-toolchain@stable @@ -874,7 +828,7 @@ jobs: RUST_BACKTRACE: full CARGO_TERM_COLOR: always steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Rust uses: dtolnay/rust-toolchain@stable @@ -898,7 +852,7 @@ jobs: env: CARGO_TERM_COLOR: always steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Rust uses: dtolnay/rust-toolchain@stable @@ -937,7 +891,7 @@ jobs: env: CARGO_TERM_COLOR: always steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Rust uses: ./.github/actions/build-rust with: @@ -960,7 +914,7 @@ jobs: run-api: ${{ steps.decision.outputs.run_api }} run-e2e: ${{ steps.decision.outputs.run_e2e }} steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - uses: dorny/paths-filter@v3 id: copilot-filter @@ -1029,7 +983,7 @@ jobs: ports: - 9308:9308 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Node.js uses: ./.github/actions/setup-node @@ -1102,7 +1056,7 @@ jobs: ports: - 9308:9308 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Node.js uses: ./.github/actions/setup-node @@ -1185,7 +1139,7 @@ jobs: ports: - 9308:9308 steps: - - uses: 
actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Node.js uses: ./.github/actions/setup-node @@ -1266,7 +1220,7 @@ jobs: test: true, } steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Node.js uses: ./.github/actions/setup-node timeout-minutes: 10 diff --git a/.github/workflows/copilot-test.yml b/.github/workflows/copilot-test.yml index 066ea81bda..24773fca17 100644 --- a/.github/workflows/copilot-test.yml +++ b/.github/workflows/copilot-test.yml @@ -10,7 +10,7 @@ jobs: env: CARGO_PROFILE_RELEASE_DEBUG: '1' steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Node.js uses: ./.github/actions/setup-node with: @@ -64,7 +64,7 @@ jobs: ports: - 9308:9308 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Node.js uses: ./.github/actions/setup-node @@ -134,7 +134,7 @@ jobs: ports: - 9308:9308 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Node.js uses: ./.github/actions/setup-node @@ -167,7 +167,7 @@ jobs: runs-on: ubuntu-latest name: Post test result message steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: fetch-depth: 0 - name: Setup Node.js diff --git a/.github/workflows/pr-title-lint.yml b/.github/workflows/pr-title-lint.yml index d8d9044dcb..dfaf04b7f2 100644 --- a/.github/workflows/pr-title-lint.yml +++ b/.github/workflows/pr-title-lint.yml @@ -18,9 +18,9 @@ jobs: runs-on: ubuntu-latest if: ${{ github.event.action != 'edited' || github.event.changes.title != null }} steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Node.js - uses: actions/setup-node@v4 + uses: actions/setup-node@v6 with: cache: 'yarn' node-version-file: '.nvmrc' diff --git a/.github/workflows/release-cloud.yml b/.github/workflows/release-cloud.yml index 18b046e71a..dc9ea56fc7 100644 --- a/.github/workflows/release-cloud.yml +++ b/.github/workflows/release-cloud.yml @@ -35,7 +35,7 @@ jobs: - build-images 
runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Deploy to ${{ inputs.build-type }} uses: ./.github/actions/deploy with: diff --git a/.github/workflows/release-desktop-platform.yml b/.github/workflows/release-desktop-platform.yml index 91b8e761bd..a531f01431 100644 --- a/.github/workflows/release-desktop-platform.yml +++ b/.github/workflows/release-desktop-platform.yml @@ -69,7 +69,7 @@ jobs: SENTRY_DSN: ${{ secrets.SENTRY_DSN }} SENTRY_RELEASE: ${{ inputs.app_version }} steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Version uses: ./.github/actions/setup-version @@ -101,7 +101,7 @@ jobs: - name: Signing By Apple Developer ID if: ${{ inputs.platform == 'darwin' && inputs.apple_codesign }} - uses: apple-actions/import-codesign-certs@v5 + uses: apple-actions/import-codesign-certs@v6 with: p12-file-base64: ${{ secrets.CERTIFICATES_P12 }} p12-password: ${{ secrets.CERTIFICATES_P12_PASSWORD }} @@ -178,14 +178,14 @@ jobs: mv packages/frontend/apps/electron/out/*/make/deb/${{ inputs.arch }}/*.deb ./builds/affine-${{ env.RELEASE_VERSION }}-${{ env.BUILD_TYPE }}-linux-${{ inputs.arch }}.deb mv packages/frontend/apps/electron/out/*/make/flatpak/*/*.flatpak ./builds/affine-${{ env.RELEASE_VERSION }}-${{ env.BUILD_TYPE }}-linux-${{ inputs.arch }}.flatpak - - uses: actions/attest-build-provenance@v2 + - uses: actions/attest-build-provenance@v4 if: ${{ inputs.platform == 'darwin' }} with: subject-path: | ./builds/affine-${{ env.RELEASE_VERSION }}-${{ env.BUILD_TYPE }}-macos-${{ inputs.arch }}.zip ./builds/affine-${{ env.RELEASE_VERSION }}-${{ env.BUILD_TYPE }}-macos-${{ inputs.arch }}.dmg - - uses: actions/attest-build-provenance@v2 + - uses: actions/attest-build-provenance@v4 if: ${{ inputs.platform == 'linux' }} with: subject-path: | diff --git a/.github/workflows/release-desktop.yml b/.github/workflows/release-desktop.yml index 275265b1b3..e102c91bf3 100644 --- 
a/.github/workflows/release-desktop.yml +++ b/.github/workflows/release-desktop.yml @@ -48,7 +48,7 @@ jobs: runs-on: ubuntu-latest environment: ${{ inputs.build-type }} steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Version uses: ./.github/actions/setup-version with: @@ -187,7 +187,7 @@ jobs: FILES_TO_BE_SIGNED_x64: ${{ steps.get_files_to_be_signed.outputs.FILES_TO_BE_SIGNED_x64 }} FILES_TO_BE_SIGNED_arm64: ${{ steps.get_files_to_be_signed.outputs.FILES_TO_BE_SIGNED_arm64 }} steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Version uses: ./.github/actions/setup-version with: @@ -344,7 +344,7 @@ jobs: mv packages/frontend/apps/electron/out/*/make/squirrel.windows/${{ matrix.spec.arch }}/*.exe ./builds/affine-${{ env.RELEASE_VERSION }}-${{ env.BUILD_TYPE }}-windows-${{ matrix.spec.arch }}.exe mv packages/frontend/apps/electron/out/*/make/nsis.windows/${{ matrix.spec.arch }}/*.exe ./builds/affine-${{ env.RELEASE_VERSION }}-${{ env.BUILD_TYPE }}-windows-${{ matrix.spec.arch }}.nsis.exe - - uses: actions/attest-build-provenance@v2 + - uses: actions/attest-build-provenance@v4 with: subject-path: | ./builds/affine-${{ env.RELEASE_VERSION }}-${{ env.BUILD_TYPE }}-windows-${{ matrix.spec.arch }}.zip @@ -369,7 +369,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Download Artifacts (macos-x64) uses: actions/download-artifact@v4 with: @@ -395,7 +395,7 @@ jobs: with: name: affine-linux-x64-builds path: ./release - - uses: actions/setup-node@v4 + - uses: actions/setup-node@v6 with: node-version: 20 - name: Copy Selfhost Release Files diff --git a/.github/workflows/release-mobile.yml b/.github/workflows/release-mobile.yml index cffad397c6..d04bcab76f 100644 --- a/.github/workflows/release-mobile.yml +++ b/.github/workflows/release-mobile.yml @@ -26,7 +26,7 @@ jobs: runs-on: ubuntu-latest environment: ${{ inputs.build-type }} steps: - - uses: 
actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Version uses: ./.github/actions/setup-version with: @@ -54,7 +54,7 @@ jobs: build-android-web: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Version uses: ./.github/actions/setup-version with: @@ -83,7 +83,7 @@ jobs: needs: - build-ios-web steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Version uses: ./.github/actions/setup-version with: @@ -114,7 +114,7 @@ jobs: - name: Cap sync run: yarn workspace @affine/ios sync - name: Signing By Apple Developer ID - uses: apple-actions/import-codesign-certs@v5 + uses: apple-actions/import-codesign-certs@v6 id: import-codesign-certs with: p12-file-base64: ${{ secrets.CERTIFICATES_P12_MOBILE }} @@ -147,7 +147,7 @@ jobs: needs: - build-android-web steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Version uses: ./.github/actions/setup-version with: @@ -180,7 +180,7 @@ jobs: no-build: 'true' - name: Cap sync run: yarn workspace @affine/android cap sync - - uses: actions/setup-python@v5 + - uses: actions/setup-python@v6 with: python-version: '3.13' - name: Auth gcloud @@ -192,7 +192,7 @@ jobs: token_format: 'access_token' project_id: '${{ secrets.GCP_PROJECT_ID }}' access_token_scopes: 'https://www.googleapis.com/auth/androidpublisher' - - uses: actions/setup-java@v4 + - uses: actions/setup-java@v5 with: distribution: 'temurin' java-version: '21' diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 98c0db20f9..8e35d77f57 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -55,7 +55,7 @@ jobs: GIT_SHORT_HASH: ${{ steps.prepare.outputs.GIT_SHORT_HASH }} BUILD_TYPE: ${{ steps.prepare.outputs.BUILD_TYPE }} steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Prepare Release id: prepare uses: ./.github/actions/prepare-release @@ -72,7 +72,7 @@ jobs: steps: - name: Decide 
whether to release id: decide - uses: actions/github-script@v7 + uses: actions/github-script@v8 with: script: | const buildType = '${{ needs.prepare.outputs.BUILD_TYPE }}' diff --git a/.gitignore b/.gitignore index 15e44010e7..08d069b6f2 100644 --- a/.gitignore +++ b/.gitignore @@ -48,6 +48,7 @@ testem.log /typings tsconfig.tsbuildinfo .context +/*.md # System Files .DS_Store diff --git a/.nvmrc b/.nvmrc index 85e502778f..32a2d7bd80 100644 --- a/.nvmrc +++ b/.nvmrc @@ -1 +1 @@ -22.22.0 +22.22.1 diff --git a/Cargo.lock b/Cargo.lock index cee519b41d..9593b0ccee 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -178,9 +178,15 @@ name = "affine_server_native" version = "1.0.0" dependencies = [ "affine_common", + "anyhow", "chrono", "file-format", + "image", "infer", + "libwebp-sys", + "little_exif", + "llm_adapter", + "matroska", "mimalloc", "mp4parse", "napi", @@ -188,6 +194,8 @@ dependencies = [ "napi-derive", "rand 0.9.2", "rayon", + "serde", + "serde_json", "sha3", "tiktoken-rs", "tokio", @@ -232,6 +240,21 @@ dependencies = [ "memchr", ] +[[package]] +name = "alloc-no-stdlib" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" + +[[package]] +name = "alloc-stdlib" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" +dependencies = [ + "alloc-no-stdlib", +] + [[package]] name = "allocator-api2" version = "0.2.21" @@ -245,7 +268,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed7572b7ba83a31e20d1b48970ee402d2e3e0537dcfe0a3ff4d6eb7508617d43" dependencies = [ "alsa-sys", - "bitflags 2.10.0", + "bitflags 2.11.0", "cfg-if", "libc", ] @@ -533,7 +556,7 @@ version = "0.72.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895" 
dependencies = [ - "bitflags 2.10.0", + "bitflags 2.11.0", "cexpr", "clang-sys", "itertools 0.13.0", @@ -583,9 +606,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.10.0" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" +checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" dependencies = [ "serde_core", ] @@ -599,6 +622,15 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "bitstream-io" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "680575de65ce8b916b82a447458b94a48776707d9c2681a9d8da351c06886a1f" +dependencies = [ + "core2", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -641,6 +673,27 @@ version = "0.9.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "473976d7a8620bb1e06dcdd184407c2363fe4fec8e983ee03ed9197222634a31" +[[package]] +name = "brotli" +version = "8.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bd8b9603c7aa97359dbd97ecf258968c95f3adddd6db2f7e7a5bef101c84560" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor", +] + +[[package]] +name = "brotli-decompressor" +version = "5.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "874bb8112abecc98cbd6d81ea4fa7e94fb9449648c93cc89aa40c81c24d7de03" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", +] + [[package]] name = "bstr" version = "1.12.1" @@ -676,6 +729,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" +[[package]] +name = "byteorder-lite" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" + [[package]] name = "bytes" version = "1.11.1" @@ -904,6 +963,12 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d" +[[package]] +name = "color_quant" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" + [[package]] name = "colorchoice" version = "1.0.4" @@ -983,7 +1048,7 @@ version = "0.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fa95a34622365fa5bbf40b20b75dba8dfa8c94c734aea8ac9a5ca38af14316f1" dependencies = [ - "bitflags 2.10.0", + "bitflags 2.11.0", "core-foundation", "core-graphics-types", "foreign-types", @@ -996,7 +1061,7 @@ version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "064badf302c3194842cf2c5d61f56cc88e54a759313879cdf03abdd27d0c3b97" dependencies = [ - "bitflags 2.10.0", + "bitflags 2.11.0", "core-foundation", "core-graphics-types", "foreign-types", @@ -1009,7 +1074,7 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d44a101f213f6c4cdc1853d4b78aef6db6bdfa3468798cc1d9912f4735013eb" dependencies = [ - "bitflags 2.10.0", + "bitflags 2.11.0", "core-foundation", "libc", ] @@ -1379,7 +1444,7 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "89a09f22a6c6069a18470eb92d2298acf25463f14256d24778e1230d789a2aec" dependencies = [ - "bitflags 2.10.0", + "bitflags 2.11.0", "block2", "libc", "objc2", @@ -1562,6 +1627,15 @@ dependencies = [ "getrandom 0.2.16", ] +[[package]] +name = "fdeflate" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e6853b52649d4ac5c0bd02320cddc5ba956bdb407c4b75a2c6b75bf51500f8c" +dependencies = [ + "simd-adler32", +] + 
[[package]] name = "file-format" version = "0.28.0" @@ -1833,6 +1907,16 @@ dependencies = [ "wasip2", ] +[[package]] +name = "gif" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5df2ba84018d80c213569363bdcd0c64e6933c67fe4c1d60ecf822971a3c35e" +dependencies = [ + "color_quant", + "weezl", +] + [[package]] name = "glob" version = "0.3.3" @@ -1880,7 +1964,7 @@ version = "1.41.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0c43e7c3212bd992c11b6b9796563388170950521ae8487f5cdf6f6e792f1c8" dependencies = [ - "bitflags 2.10.0", + "bitflags 2.11.0", "proc-macro2", "quote", "syn 1.0.109", @@ -2003,6 +2087,22 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "http" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +dependencies = [ + "bytes", + "itoa", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + [[package]] name = "iana-time-zone" version = "0.1.64" @@ -2175,6 +2275,34 @@ dependencies = [ "icu_properties", ] +[[package]] +name = "image" +version = "0.25.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6506c6c10786659413faa717ceebcb8f70731c0a60cbae39795fdf114519c1a" +dependencies = [ + "bytemuck", + "byteorder-lite", + "color_quant", + "gif", + "image-webp", + "moxcms", + "num-traits", + "png", + "zune-core", + "zune-jpeg", +] + +[[package]] +name = "image-webp" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "525e9ff3e1a4be2fbea1fdf0e98686a6d98b4d8f937e1bf7402245af1909e8c3" +dependencies = [ + "byteorder-lite", + "quick-error 2.0.1", +] + [[package]] name = "include-flate" version = "0.3.1" @@ -2376,9 +2504,9 @@ dependencies = [ 
[[package]] name = "keccak" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ecc2af9a1119c51f12a14607e783cb977bde58bc069ff0c3da1095e635d70654" +checksum = "cb26cec98cce3a3d96cbb7bced3c4b16e3d13f27ec56dbd62cbc8f39cfb9d653" dependencies = [ "cpufeatures", ] @@ -2490,7 +2618,7 @@ version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d0b95e02c851351f877147b7deea7b1afb1df71b63aa5f8270716e0c5720616" dependencies = [ - "bitflags 2.10.0", + "bitflags 2.11.0", "libc", "redox_syscall 0.7.0", ] @@ -2506,6 +2634,17 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "libwebp-sys" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "375ca3fbd6d89769361c5d505c9da676eb4128ee471b9fd763144d377a2d30e6" +dependencies = [ + "cc", + "glob", + "pkg-config", +] + [[package]] name = "linux-raw-sys" version = "0.11.0" @@ -2518,6 +2657,33 @@ version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" +[[package]] +name = "little_exif" +version = "0.6.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21eeb58b22d31be8dc5c625004fcd4b9b385cd3c05df575f523bcca382c51122" +dependencies = [ + "brotli", + "crc", + "log", + "miniz_oxide", + "paste", + "quick-xml", +] + +[[package]] +name = "llm_adapter" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e98485dda5180cc89b993a001688bed93307be6bd8fedcde445b69bbca4f554d" +dependencies = [ + "base64", + "serde", + "serde_json", + "thiserror 2.0.17", + "ureq", +] + [[package]] name = "lock_api" version = "0.4.14" @@ -2555,7 +2721,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "59fa2559e99ba0f26a12458aabc754432c805bbb8cba516c427825a997af1fb7" dependencies = [ "aes", - "bitflags 2.10.0", + 
"bitflags 2.11.0", "cbc", "ecb", "encoding_rs", @@ -2642,6 +2808,16 @@ dependencies = [ "regex-automata", ] +[[package]] +name = "matroska" +version = "0.30.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fde85cd7fb5cf875c4a46fac0cbd6567d413bea2538cef6788e3a0e52a902b45" +dependencies = [ + "bitstream-io", + "phf 0.11.3", +] + [[package]] name = "md-5" version = "0.10.6" @@ -2711,6 +2887,16 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "moxcms" +version = "0.7.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac9557c559cd6fc9867e122e20d2cbefc9ca29d80d027a8e39310920ed2f0a97" +dependencies = [ + "num-traits", + "pxfm", +] + [[package]] name = "mp4parse" version = "0.17.0" @@ -2741,7 +2927,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "000f205daae6646003fdc38517be6232af2b150bad4b67bdaf4c5aadb119d738" dependencies = [ "anyhow", - "bitflags 2.10.0", + "bitflags 2.11.0", "chrono", "ctor", "futures", @@ -2801,7 +2987,7 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2076a31b7010b17a38c01907c45b945e8f11495ee4dd588309718901b1f7a5b7" dependencies = [ - "bitflags 2.10.0", + "bitflags 2.11.0", "jni-sys", "log", "ndk-sys", @@ -2836,7 +3022,7 @@ version = "0.30.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "74523f3a35e05aba87a1d978330aef40f67b0304ac79c1c00b294c9830543db6" dependencies = [ - "bitflags 2.10.0", + "bitflags 2.11.0", "cfg-if", "cfg_aliases", "libc", @@ -3000,7 +3186,7 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a180dd8642fa45cdb7dd721cd4c11b1cadd4929ce112ebd8b9f5803cc79d536" dependencies = [ - "bitflags 2.10.0", + "bitflags 2.11.0", "dispatch2", "objc2", ] @@ -3017,7 +3203,7 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"e3e0adef53c21f888deb4fa59fc59f7eb17404926ee8a6f59f5df0fd7f9f3272" dependencies = [ - "bitflags 2.10.0", + "bitflags 2.11.0", "block2", "libc", "objc2", @@ -3123,6 +3309,12 @@ dependencies = [ "windows-link 0.2.1", ] +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "path-ext" version = "0.1.2" @@ -3369,6 +3561,19 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "png" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60769b8b31b2a9f263dae2776c37b1b28ae246943cf719eb6946a1db05128a61" +dependencies = [ + "bitflags 2.11.0", + "crc32fast", + "fdeflate", + "flate2", + "miniz_oxide", +] + [[package]] name = "pom" version = "1.1.0" @@ -3469,7 +3674,7 @@ checksum = "bee689443a2bd0a16ab0348b52ee43e3b2d1b1f931c8aa5c9f8de4c86fbe8c40" dependencies = [ "bit-set 0.8.0", "bit-vec 0.8.0", - "bitflags 2.10.0", + "bitflags 2.11.0", "num-traits", "rand 0.9.2", "rand_chacha 0.9.0", @@ -3497,7 +3702,7 @@ version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e8bbe1a966bd2f362681a44f6edce3c2310ac21e4d5067a6e7ec396297a6ea0" dependencies = [ - "bitflags 2.10.0", + "bitflags 2.11.0", "getopts", "memchr", "pulldown-cmark-escape", @@ -3510,12 +3715,33 @@ version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "007d8adb5ddab6f8e3f491ac63566a7d5002cc7ed73901f72057943fa71ae1ae" +[[package]] +name = "pxfm" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5a041e753da8b807c9255f28de81879c78c876392ff2469cde94799b2896b9d" + [[package]] name = "quick-error" version = "1.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" +[[package]] +name = 
"quick-error" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" + +[[package]] +name = "quick-xml" +version = "0.37.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "331e97a1af0bf59823e6eadffe373d7b27f485be8748f71471c662c1f269b7fb" +dependencies = [ + "memchr", +] + [[package]] name = "quote" version = "1.0.43" @@ -3663,7 +3889,7 @@ version = "0.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" dependencies = [ - "bitflags 2.10.0", + "bitflags 2.11.0", ] [[package]] @@ -3672,7 +3898,7 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49f3fe0889e69e2ae9e41f4d6c4c0181701d00e4697b356fb1f74173a5e0ee27" dependencies = [ - "bitflags 2.10.0", + "bitflags 2.11.0", ] [[package]] @@ -3831,7 +4057,7 @@ version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" dependencies = [ - "bitflags 2.10.0", + "bitflags 2.11.0", "errno", "libc", "linux-raw-sys", @@ -3844,6 +4070,7 @@ version = "0.23.36" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c665f33d38cea657d9614f766881e4d510e0eda4239891eea56b4cadcf01801b" dependencies = [ + "log", "once_cell", "ring", "rustls-pki-types", @@ -3885,7 +4112,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cc6bf79ff24e648f6da1f8d1f011e9cac26491b619e6b9280f2b47f1774e6ee2" dependencies = [ "fnv", - "quick-error", + "quick-error 1.2.3", "tempfile", "wait-timeout", ] @@ -4269,7 +4496,7 @@ checksum = "aa003f0038df784eb8fecbbac13affe3da23b45194bd57dba231c8f48199c526" dependencies = [ "atoi", "base64", - "bitflags 2.10.0", + "bitflags 2.11.0", "byteorder", "bytes", "chrono", @@ -4312,7 +4539,7 @@ checksum 
= "db58fcd5a53cf07c184b154801ff91347e4c30d17a3562a635ff028ad5deda46" dependencies = [ "atoi", "base64", - "bitflags 2.10.0", + "bitflags 2.11.0", "byteorder", "chrono", "crc", @@ -4912,9 +5139,9 @@ dependencies = [ [[package]] name = "toml_parser" -version = "1.0.6+spec-1.1.0" +version = "1.0.9+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3198b4b0a8e11f09dd03e133c0280504d0801269e9afa46362ffde1cbeebf44" +checksum = "702d4415e08923e7e1ef96cd5727c0dfed80b4d2fa25db9647fe5eb6f7c5a4c4" dependencies = [ "winnow", ] @@ -5345,6 +5572,35 @@ version = "0.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6d49784317cd0d1ee7ec5c716dd598ec5b4483ea832a2dced265471cc0f690ae" +[[package]] +name = "ureq" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdc97a28575b85cfedf2a7e7d3cc64b3e11bd8ac766666318003abbacc7a21fc" +dependencies = [ + "base64", + "flate2", + "log", + "percent-encoding", + "rustls", + "rustls-pki-types", + "ureq-proto", + "utf-8", + "webpki-roots 1.0.5", +] + +[[package]] +name = "ureq-proto" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d81f9efa9df032be5934a46a068815a10a042b494b6a58cb0a1a97bb5467ed6f" +dependencies = [ + "base64", + "http", + "httparse", + "log", +] + [[package]] name = "url" version = "2.5.8" @@ -6310,3 +6566,18 @@ dependencies = [ "cc", "pkg-config", ] + +[[package]] +name = "zune-core" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb8a0807f7c01457d0379ba880ba6322660448ddebc890ce29bb64da71fb40f9" + +[[package]] +name = "zune-jpeg" +version = "0.5.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "410e9ecef634c709e3831c2cfdb8d9c32164fae1c67496d5b68fff728eec37fe" +dependencies = [ + "zune-core", +] diff --git a/Cargo.toml b/Cargo.toml index c64a297349..a547b0f89a 100644 --- a/Cargo.toml +++ 
b/Cargo.toml @@ -40,13 +40,24 @@ resolver = "3" dotenvy = "0.15" file-format = { version = "0.28", features = ["reader"] } homedir = "0.3" + image = { version = "0.25.9", default-features = false, features = [ + "bmp", + "gif", + "jpeg", + "png", + "webp", + ] } infer = { version = "0.19.0" } lasso = { version = "0.7", features = ["multi-threaded"] } lib0 = { version = "0.16", features = ["lib0-serde"] } libc = "0.2" + libwebp-sys = "0.14.2" + little_exif = "0.6.23" + llm_adapter = { version = "0.1.3", default-features = false } log = "0.4" loom = { version = "0.7", features = ["checkpoint"] } lru = "0.16" + matroska = "0.30" memory-indexer = "0.3.0" mimalloc = "0.1" mp4parse = "0.17" diff --git a/SECURITY.md b/SECURITY.md index 374efb5aed..f493f19d58 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -23,4 +23,6 @@ We welcome you to provide us with bug reports via and email at [security@toevery Since we are an open source project, we also welcome you to provide corresponding fix PRs, we will determine specific rewards based on the evaluation results. +Due to limited resources, we do not accept and will not review any AI-generated security reports. + If the vulnerability is caused by a library we depend on, we encourage you to submit a security report to the corresponding dependent library at the same time to benefit more users. 
diff --git a/blocksuite/affine/all/package.json b/blocksuite/affine/all/package.json index abba782d27..3494b132cb 100644 --- a/blocksuite/affine/all/package.json +++ b/blocksuite/affine/all/package.json @@ -300,6 +300,6 @@ "devDependencies": { "@vanilla-extract/vite-plugin": "^5.0.0", "msw": "^2.12.4", - "vitest": "^3.2.4" + "vitest": "^4.0.18" } } diff --git a/blocksuite/affine/all/vitest.config.ts b/blocksuite/affine/all/vitest.config.ts index c2625c985b..ce1fac3269 100644 --- a/blocksuite/affine/all/vitest.config.ts +++ b/blocksuite/affine/all/vitest.config.ts @@ -11,7 +11,7 @@ export default defineConfig({ include: ['src/__tests__/**/*.unit.spec.ts'], testTimeout: 1000, coverage: { - provider: 'istanbul', // or 'c8' + provider: 'istanbul', reporter: ['lcov'], reportsDirectory: '../../../.coverage/blocksuite-affine', }, diff --git a/blocksuite/affine/blocks/bookmark/package.json b/blocksuite/affine/blocks/bookmark/package.json index b213714cfc..334a0705b1 100644 --- a/blocksuite/affine/blocks/bookmark/package.json +++ b/blocksuite/affine/blocks/bookmark/package.json @@ -31,7 +31,8 @@ "zod": "^3.25.76" }, "devDependencies": { - "vitest": "^3.2.4" + "@vitest/browser-playwright": "^4.0.18", + "vitest": "^4.0.18" }, "exports": { ".": "./src/index.ts", diff --git a/blocksuite/affine/blocks/bookmark/src/bookmark-block.ts b/blocksuite/affine/blocks/bookmark/src/bookmark-block.ts index a591397ff0..8d944a1050 100644 --- a/blocksuite/affine/blocks/bookmark/src/bookmark-block.ts +++ b/blocksuite/affine/blocks/bookmark/src/bookmark-block.ts @@ -108,7 +108,9 @@ export class BookmarkBlockComponent extends CaptionedBlockComponent { - window.open(this.link, '_blank'); + const link = this.link; + if (!link) return; + window.open(link, '_blank', 'noopener,noreferrer'); }; refreshData = () => { diff --git a/blocksuite/affine/blocks/bookmark/vitest.config.ts b/blocksuite/affine/blocks/bookmark/vitest.config.ts index 255be530ba..1752e9d318 100644 --- 
a/blocksuite/affine/blocks/bookmark/vitest.config.ts +++ b/blocksuite/affine/blocks/bookmark/vitest.config.ts @@ -1,3 +1,4 @@ +import { playwright } from '@vitest/browser-playwright'; import { defineConfig } from 'vitest/config'; export default defineConfig({ @@ -8,10 +9,9 @@ export default defineConfig({ browser: { enabled: true, headless: true, - name: 'chromium', - provider: 'playwright', + instances: [{ browser: 'chromium' }], + provider: playwright(), isolate: false, - providerOptions: {}, }, include: ['src/__tests__/**/*.unit.spec.ts'], testTimeout: 500, diff --git a/blocksuite/affine/blocks/code/src/highlight/affine-code-unit.ts b/blocksuite/affine/blocks/code/src/highlight/affine-code-unit.ts index 1fff01008c..f700c2e40c 100644 --- a/blocksuite/affine/blocks/code/src/highlight/affine-code-unit.ts +++ b/blocksuite/affine/blocks/code/src/highlight/affine-code-unit.ts @@ -45,8 +45,10 @@ export class AffineCodeUnit extends ShadowlessElement { if (!codeBlock || !vElement) return plainContent; const tokens = codeBlock.highlightTokens$.value; if (tokens.length === 0) return plainContent; + const line = tokens[vElement.lineIndex]; + if (!line) return plainContent; // copy the tokens to avoid modifying the original tokens - const lineTokens = structuredClone(tokens[vElement.lineIndex]); + const lineTokens = structuredClone(line); if (lineTokens.length === 0) return plainContent; const startOffset = vElement.startOffset; diff --git a/blocksuite/affine/blocks/embed-doc/package.json b/blocksuite/affine/blocks/embed-doc/package.json index bc9a89e7ec..71bbf4d3a0 100644 --- a/blocksuite/affine/blocks/embed-doc/package.json +++ b/blocksuite/affine/blocks/embed-doc/package.json @@ -35,7 +35,7 @@ "zod": "^3.25.76" }, "devDependencies": { - "vitest": "^3.2.4" + "vitest": "^4.0.18" }, "exports": { ".": "./src/index.ts", diff --git a/blocksuite/affine/blocks/embed/package.json b/blocksuite/affine/blocks/embed/package.json index a22156072f..5ee9d71d17 100644 --- 
a/blocksuite/affine/blocks/embed/package.json +++ b/blocksuite/affine/blocks/embed/package.json @@ -35,7 +35,7 @@ "zod": "^3.25.76" }, "devDependencies": { - "vitest": "^3.2.4" + "vitest": "^4.0.18" }, "exports": { ".": "./src/index.ts", diff --git a/blocksuite/affine/blocks/list/package.json b/blocksuite/affine/blocks/list/package.json index 00bc801560..ab21e70b5e 100644 --- a/blocksuite/affine/blocks/list/package.json +++ b/blocksuite/affine/blocks/list/package.json @@ -31,7 +31,7 @@ "zod": "^3.25.76" }, "devDependencies": { - "vitest": "^3.2.4" + "vitest": "^4.0.18" }, "exports": { ".": "./src/index.ts", diff --git a/blocksuite/affine/blocks/note/src/note-edgeless-block.ts b/blocksuite/affine/blocks/note/src/note-edgeless-block.ts index bc2829e7a1..639bd0a185 100644 --- a/blocksuite/affine/blocks/note/src/note-edgeless-block.ts +++ b/blocksuite/affine/blocks/note/src/note-edgeless-block.ts @@ -221,6 +221,12 @@ export class EdgelessNoteBlockComponent extends toGfxBlockComponent( } } + override getCSSScaleVal(): number { + const baseScale = super.getCSSScaleVal(); + const extraScale = this.model.props.edgeless?.scale ?? 1; + return baseScale * extraScale; + } + override getRenderingRect() { const { xywh, edgeless } = this.model.props; const { collapse, scale = 1 } = edgeless; @@ -255,7 +261,6 @@ export class EdgelessNoteBlockComponent extends toGfxBlockComponent( const style = { borderRadius: borderRadius + 'px', - transform: `scale(${scale})`, }; const extra = this._editing ? 
ACTIVE_NOTE_EXTRA_PADDING : 0; @@ -454,6 +459,28 @@ export const EdgelessNoteInteraction = return; } + let isClickOnTitle = false; + const titleRect = view + .querySelector('edgeless-page-block-title') + ?.getBoundingClientRect(); + + if (titleRect) { + const titleBound = new Bound( + titleRect.x, + titleRect.y, + titleRect.width, + titleRect.height + ); + if (titleBound.isPointInBound([e.clientX, e.clientY])) { + isClickOnTitle = true; + } + } + + if (isClickOnTitle) { + handleNativeRangeAtPoint(e.clientX, e.clientY); + return; + } + if (model.children.length === 0) { const blockId = std.store.addBlock( 'affine:paragraph', diff --git a/blocksuite/affine/blocks/root/src/edgeless/configs/toolbar/more.ts b/blocksuite/affine/blocks/root/src/edgeless/configs/toolbar/more.ts index 6374cf8f08..d4a3651d1f 100644 --- a/blocksuite/affine/blocks/root/src/edgeless/configs/toolbar/more.ts +++ b/blocksuite/affine/blocks/root/src/edgeless/configs/toolbar/more.ts @@ -22,6 +22,7 @@ import { FrameBlockModel, ImageBlockModel, isExternalEmbedModel, + MindmapElementModel, NoteBlockModel, ParagraphBlockModel, } from '@blocksuite/affine-model'; @@ -401,7 +402,17 @@ function reorderElements( ) { if (!models.length) return; - for (const model of models) { + const normalizedModels = Array.from( + new Map( + models.map(model => { + const reorderTarget = + model.group instanceof MindmapElementModel ? 
model.group : model; + return [reorderTarget.id, reorderTarget]; + }) + ).values() + ); + + for (const model of normalizedModels) { const index = ctx.gfx.layer.getReorderedIndex(model, type); // block should be updated in transaction diff --git a/blocksuite/affine/blocks/surface/package.json b/blocksuite/affine/blocks/surface/package.json index 5b1114acde..3b397799d9 100644 --- a/blocksuite/affine/blocks/surface/package.json +++ b/blocksuite/affine/blocks/surface/package.json @@ -33,7 +33,7 @@ "zod": "^3.25.76" }, "devDependencies": { - "vitest": "^3.2.4" + "vitest": "^4.0.18" }, "exports": { ".": "./src/index.ts", diff --git a/blocksuite/affine/blocks/surface/src/renderer/canvas-renderer.ts b/blocksuite/affine/blocks/surface/src/renderer/canvas-renderer.ts index 86f53a7398..04c8426255 100644 --- a/blocksuite/affine/blocks/surface/src/renderer/canvas-renderer.ts +++ b/blocksuite/affine/blocks/surface/src/renderer/canvas-renderer.ts @@ -2,16 +2,24 @@ import { type Color, ColorScheme } from '@blocksuite/affine-model'; import { FeatureFlagService } from '@blocksuite/affine-shared/services'; import { requestConnectedFrame } from '@blocksuite/affine-shared/utils'; import { DisposableGroup } from '@blocksuite/global/disposable'; -import type { IBound } from '@blocksuite/global/gfx'; -import { getBoundWithRotation, intersects } from '@blocksuite/global/gfx'; +import { + Bound, + getBoundWithRotation, + type IBound, + intersects, +} from '@blocksuite/global/gfx'; import type { BlockStdScope } from '@blocksuite/std'; import type { GfxCompatibleInterface, + GfxController, + GfxLocalElementModel, GridManager, LayerManager, SurfaceBlockModel, Viewport, } from '@blocksuite/std/gfx'; +import { GfxControllerIdentifier } from '@blocksuite/std/gfx'; +import { effect } from '@preact/signals-core'; import last from 'lodash-es/last'; import { Subject } from 'rxjs'; @@ -40,11 +48,82 @@ type RendererOptions = { surfaceModel: SurfaceBlockModel; }; +export type CanvasRenderPassMetrics = { 
+ overlayCount: number; + placeholderElementCount: number; + renderByBoundCallCount: number; + renderedElementCount: number; + visibleElementCount: number; +}; + +export type CanvasMemorySnapshot = { + bytes: number; + datasetLayerId: string | null; + height: number; + kind: 'main' | 'stacking'; + width: number; + zIndex: string; +}; + +export type CanvasRendererDebugMetrics = { + canvasLayerCount: number; + canvasMemoryBytes: number; + canvasMemorySnapshots: CanvasMemorySnapshot[]; + canvasMemoryMegabytes: number; + canvasPixelCount: number; + coalescedRefreshCount: number; + dirtyLayerRenderCount: number; + fallbackElementCount: number; + lastRenderDurationMs: number; + lastRenderMetrics: CanvasRenderPassMetrics; + maxRenderDurationMs: number; + pooledStackingCanvasCount: number; + refreshCount: number; + renderCount: number; + stackingCanvasCount: number; + totalLayerCount: number; + totalRenderDurationMs: number; + visibleStackingCanvasCount: number; +}; + +type MutableCanvasRendererDebugMetrics = Omit< + CanvasRendererDebugMetrics, + | 'canvasLayerCount' + | 'canvasMemoryBytes' + | 'canvasMemoryMegabytes' + | 'canvasPixelCount' + | 'canvasMemorySnapshots' + | 'pooledStackingCanvasCount' + | 'stackingCanvasCount' + | 'totalLayerCount' + | 'visibleStackingCanvasCount' +>; + +type RenderPassStats = CanvasRenderPassMetrics; + +type StackingCanvasState = { + bound: Bound | null; + layerId: string | null; +}; + +type RefreshTarget = + | { type: 'all' } + | { type: 'main' } + | { type: 'element'; element: SurfaceElementModel | GfxLocalElementModel } + | { + type: 'elements'; + elements: Array; + }; + +const STACKING_CANVAS_PADDING = 32; + export class CanvasRenderer { private _container!: HTMLElement; private readonly _disposables = new DisposableGroup(); + private readonly _gfx: GfxController; + private readonly _turboEnabled: () => boolean; private readonly _overlays = new Set(); @@ -53,6 +132,37 @@ export class CanvasRenderer { private _stackingCanvas: 
HTMLCanvasElement[] = []; + private readonly _stackingCanvasPool: HTMLCanvasElement[] = []; + + private readonly _stackingCanvasState = new WeakMap< + HTMLCanvasElement, + StackingCanvasState + >(); + + private readonly _dirtyStackingCanvasIndexes = new Set(); + + private _mainCanvasDirty = true; + + private _needsFullRender = true; + + private _debugMetrics: MutableCanvasRendererDebugMetrics = { + refreshCount: 0, + coalescedRefreshCount: 0, + renderCount: 0, + totalRenderDurationMs: 0, + lastRenderDurationMs: 0, + maxRenderDurationMs: 0, + lastRenderMetrics: { + renderByBoundCallCount: 0, + visibleElementCount: 0, + renderedElementCount: 0, + placeholderElementCount: 0, + overlayCount: 0, + }, + dirtyLayerRenderCount: 0, + fallbackElementCount: 0, + }; + canvas: HTMLCanvasElement; ctx: CanvasRenderingContext2D; @@ -89,6 +199,7 @@ export class CanvasRenderer { this.layerManager = options.layerManager; this.grid = options.gridManager; this.provider = options.provider ?? {}; + this._gfx = this.std.get(GfxControllerIdentifier); this._turboEnabled = () => { const featureFlagService = options.std.get(FeatureFlagService); @@ -132,15 +243,199 @@ export class CanvasRenderer { }; } + private _applyStackingCanvasLayout( + canvas: HTMLCanvasElement, + bound: Bound | null, + dpr = window.devicePixelRatio + ) { + const state = + this._stackingCanvasState.get(canvas) ?? + ({ + bound: null, + layerId: canvas.dataset.layerId ?? null, + } satisfies StackingCanvasState); + + if (!bound || bound.w <= 0 || bound.h <= 0) { + canvas.style.display = 'none'; + canvas.style.left = '0px'; + canvas.style.top = '0px'; + canvas.style.width = '0px'; + canvas.style.height = '0px'; + canvas.style.transform = ''; + canvas.width = 0; + canvas.height = 0; + state.bound = null; + state.layerId = canvas.dataset.layerId ?? 
null; + this._stackingCanvasState.set(canvas, state); + return; + } + + const { viewportBounds, zoom, viewScale } = this.viewport; + const width = bound.w * zoom; + const height = bound.h * zoom; + const left = (bound.x - viewportBounds.x) * zoom; + const top = (bound.y - viewportBounds.y) * zoom; + const actualWidth = Math.max(1, Math.ceil(width * dpr)); + const actualHeight = Math.max(1, Math.ceil(height * dpr)); + const transform = `translate(${left}px, ${top}px) scale(${1 / viewScale})`; + + if (canvas.style.display !== 'block') { + canvas.style.display = 'block'; + } + if (canvas.style.left !== '0px') { + canvas.style.left = '0px'; + } + if (canvas.style.top !== '0px') { + canvas.style.top = '0px'; + } + if (canvas.style.width !== `${width}px`) { + canvas.style.width = `${width}px`; + } + if (canvas.style.height !== `${height}px`) { + canvas.style.height = `${height}px`; + } + if (canvas.style.transform !== transform) { + canvas.style.transform = transform; + } + if (canvas.style.transformOrigin !== 'top left') { + canvas.style.transformOrigin = 'top left'; + } + + if (canvas.width !== actualWidth) { + canvas.width = actualWidth; + } + + if (canvas.height !== actualHeight) { + canvas.height = actualHeight; + } + + state.bound = bound; + state.layerId = canvas.dataset.layerId ?? 
null; + this._stackingCanvasState.set(canvas, state); + } + + private _clampBoundToViewport(bound: Bound, viewportBounds: Bound) { + const minX = Math.max(bound.x, viewportBounds.x); + const minY = Math.max(bound.y, viewportBounds.y); + const maxX = Math.min(bound.maxX, viewportBounds.maxX); + const maxY = Math.min(bound.maxY, viewportBounds.maxY); + + if (maxX <= minX || maxY <= minY) { + return null; + } + + return new Bound(minX, minY, maxX - minX, maxY - minY); + } + + private _createCanvasForLayer( + onCreated?: (canvas: HTMLCanvasElement) => void + ) { + const reused = this._stackingCanvasPool.pop(); + + if (reused) { + return reused; + } + + const created = document.createElement('canvas'); + onCreated?.(created); + return created; + } + + private _findLayerIndexByElement( + element: SurfaceElementModel | GfxLocalElementModel + ) { + const canvasLayers = this.layerManager.getCanvasLayers(); + const index = canvasLayers.findIndex(layer => + layer.elements.some(layerElement => layerElement.id === element.id) + ); + + return index === -1 ? null : index; + } + + private _getLayerRenderBound( + elements: SurfaceElementModel[], + viewportBounds: Bound + ) { + let layerBound: Bound | null = null; + + for (const element of elements) { + const display = (element.display ?? true) && !element.hidden; + + if (!display) { + continue; + } + + const elementBound = Bound.from(getBoundWithRotation(element)); + + if (!intersects(elementBound, viewportBounds)) { + continue; + } + + layerBound = layerBound ? layerBound.unite(elementBound) : elementBound; + } + + if (!layerBound) { + return null; + } + + return this._clampBoundToViewport( + layerBound.expand(STACKING_CANVAS_PADDING), + viewportBounds + ); + } + + private _getResolvedStackingCanvasBound( + canvas: HTMLCanvasElement, + bound: Bound | null + ) { + if (!bound || !this._gfx.tool.dragging$.peek()) { + return bound; + } + + const previousBound = this._stackingCanvasState.get(canvas)?.bound; + + return previousBound ? 
previousBound.unite(bound) : bound; + } + + private _invalidate(target: RefreshTarget = { type: 'all' }) { + if (target.type === 'all') { + this._needsFullRender = true; + this._mainCanvasDirty = true; + this._dirtyStackingCanvasIndexes.clear(); + return; + } + + if (this._needsFullRender) { + return; + } + + if (target.type === 'main') { + this._mainCanvasDirty = true; + return; + } + + const elements = + target.type === 'element' ? [target.element] : target.elements; + + for (const element of elements) { + const layerIndex = this._findLayerIndexByElement(element); + + if (layerIndex === null || layerIndex >= this._stackingCanvas.length) { + this._mainCanvasDirty = true; + continue; + } + + this._dirtyStackingCanvasIndexes.add(layerIndex); + } + } + + private _resetPooledCanvas(canvas: HTMLCanvasElement) { + canvas.dataset.layerId = ''; + this._applyStackingCanvasLayout(canvas, null); + } + private _initStackingCanvas(onCreated?: (canvas: HTMLCanvasElement) => void) { const layer = this.layerManager; - const updateStackingCanvasSize = (canvases: HTMLCanvasElement[]) => { - this._stackingCanvas = canvases; - - const sizeUpdater = this._canvasSizeUpdater(); - - canvases.filter(sizeUpdater.filter).forEach(sizeUpdater.update); - }; const updateStackingCanvas = () => { /** * we already have a main canvas, so the last layer should be skipped @@ -159,11 +454,7 @@ export class CanvasRenderer { const created = i < currentCanvases.length; const canvas = created ? 
currentCanvases[i] - : document.createElement('canvas'); - - if (!created) { - onCreated?.(canvas); - } + : this._createCanvasForLayer(onCreated); canvas.dataset.layerId = `[${layer.indexes[0]}--${layer.indexes[1]}]`; canvas.style.zIndex = layer.zIndex.toString(); @@ -171,7 +462,6 @@ export class CanvasRenderer { } this._stackingCanvas = canvases; - updateStackingCanvasSize(canvases); if (currentCanvases.length !== canvases.length) { const diff = canvases.length - currentCanvases.length; @@ -189,12 +479,16 @@ export class CanvasRenderer { payload.added = canvases.slice(-diff); } else { payload.removed = currentCanvases.slice(diff); + payload.removed.forEach(canvas => { + this._resetPooledCanvas(canvas); + this._stackingCanvasPool.push(canvas); + }); } this.stackingCanvasUpdated.next(payload); } - this.refresh(); + this.refresh({ type: 'all' }); }; this._disposables.add( @@ -211,7 +505,7 @@ export class CanvasRenderer { this._disposables.add( this.viewport.viewportUpdated.subscribe(() => { - this.refresh(); + this.refresh({ type: 'all' }); }) ); @@ -222,7 +516,6 @@ export class CanvasRenderer { sizeUpdatedRafId = null; this._resetSize(); this._render(); - this.refresh(); }, this._container); }) ); @@ -233,69 +526,212 @@ export class CanvasRenderer { if (this.usePlaceholder !== shouldRenderPlaceholders) { this.usePlaceholder = shouldRenderPlaceholders; - this.refresh(); + this.refresh({ type: 'all' }); } }) ); + let wasDragging = false; + this._disposables.add( + effect(() => { + const isDragging = this._gfx.tool.dragging$.value; + + if (wasDragging && !isDragging) { + this.refresh({ type: 'all' }); + } + + wasDragging = isDragging; + }) + ); + this.usePlaceholder = false; } + private _createRenderPassStats(): RenderPassStats { + return { + renderByBoundCallCount: 0, + visibleElementCount: 0, + renderedElementCount: 0, + placeholderElementCount: 0, + overlayCount: 0, + }; + } + + private _getCanvasMemorySnapshots(): CanvasMemorySnapshot[] { + return [this.canvas, 
...this._stackingCanvas].map((canvas, index) => { + return { + kind: index === 0 ? 'main' : 'stacking', + width: canvas.width, + height: canvas.height, + bytes: canvas.width * canvas.height * 4, + zIndex: canvas.style.zIndex, + datasetLayerId: canvas.dataset.layerId ?? null, + }; + }); + } + private _render() { + const renderStart = performance.now(); const { viewportBounds, zoom } = this.viewport; const { ctx } = this; const dpr = window.devicePixelRatio; const scale = zoom * dpr; const matrix = new DOMMatrix().scaleSelf(scale); + const renderStats = this._createRenderPassStats(); + const fullRender = this._needsFullRender; + const stackingIndexesToRender = fullRender + ? this._stackingCanvas.map((_, idx) => idx) + : [...this._dirtyStackingCanvasIndexes]; /** * if a layer does not have a corresponding canvas * its element will be add to this array and drawing on the * main canvas */ let fallbackElement: SurfaceElementModel[] = []; + const allCanvasLayers = this.layerManager.getCanvasLayers(); + const viewportBound = Bound.from(viewportBounds); - this.layerManager.getCanvasLayers().forEach((layer, idx) => { - if (!this._stackingCanvas[idx]) { - fallbackElement = fallbackElement.concat(layer.elements); - return; + for (const idx of stackingIndexesToRender) { + const layer = allCanvasLayers[idx]; + const canvas = this._stackingCanvas[idx]; + + if (!layer || !canvas) { + continue; } - const canvas = this._stackingCanvas[idx]; - const ctx = canvas.getContext('2d') as CanvasRenderingContext2D; - const rc = new RoughCanvas(ctx.canvas); + const layerRenderBound = this._getLayerRenderBound( + layer.elements, + viewportBound + ); + const resolvedLayerRenderBound = this._getResolvedStackingCanvasBound( + canvas, + layerRenderBound + ); - ctx.clearRect(0, 0, canvas.width, canvas.height); + this._applyStackingCanvasLayout(canvas, resolvedLayerRenderBound); + + if ( + !resolvedLayerRenderBound || + canvas.width === 0 || + canvas.height === 0 + ) { + continue; + } + + const 
layerCtx = canvas.getContext('2d') as CanvasRenderingContext2D; + const layerRc = new RoughCanvas(layerCtx.canvas); + + layerCtx.clearRect(0, 0, canvas.width, canvas.height); + layerCtx.save(); + layerCtx.setTransform(matrix); + + this._renderByBound( + layerCtx, + matrix, + layerRc, + resolvedLayerRenderBound, + layer.elements, + false, + renderStats + ); + } + + if (fullRender || this._mainCanvasDirty) { + allCanvasLayers.forEach((layer, idx) => { + if (!this._stackingCanvas[idx]) { + fallbackElement = fallbackElement.concat(layer.elements); + } + }); + + ctx.clearRect(0, 0, this.canvas.width, this.canvas.height); ctx.save(); ctx.setTransform(matrix); - this._renderByBound(ctx, matrix, rc, viewportBounds, layer.elements); - }); + this._renderByBound( + ctx, + matrix, + new RoughCanvas(ctx.canvas), + viewportBounds, + fallbackElement, + true, + renderStats + ); + } - ctx.clearRect(0, 0, this.canvas.width, this.canvas.height); - ctx.save(); - - ctx.setTransform(matrix); - - this._renderByBound( - ctx, - matrix, - new RoughCanvas(ctx.canvas), - viewportBounds, - fallbackElement, - true + const canvasMemorySnapshots = this._getCanvasMemorySnapshots(); + const canvasMemoryBytes = canvasMemorySnapshots.reduce( + (sum, snapshot) => sum + snapshot.bytes, + 0 ); + const layerTypes = this.layerManager.layers.map(layer => layer.type); + const renderDurationMs = performance.now() - renderStart; + + this._debugMetrics.renderCount += 1; + this._debugMetrics.totalRenderDurationMs += renderDurationMs; + this._debugMetrics.lastRenderDurationMs = renderDurationMs; + this._debugMetrics.maxRenderDurationMs = Math.max( + this._debugMetrics.maxRenderDurationMs, + renderDurationMs + ); + this._debugMetrics.lastRenderMetrics = renderStats; + this._debugMetrics.fallbackElementCount = fallbackElement.length; + this._debugMetrics.dirtyLayerRenderCount = stackingIndexesToRender.length; + + this._lastDebugSnapshot = { + canvasMemorySnapshots, + canvasMemoryBytes, + canvasPixelCount: 
canvasMemorySnapshots.reduce( + (sum, snapshot) => sum + snapshot.width * snapshot.height, + 0 + ), + stackingCanvasCount: this._stackingCanvas.length, + canvasLayerCount: layerTypes.filter(type => type === 'canvas').length, + totalLayerCount: layerTypes.length, + pooledStackingCanvasCount: this._stackingCanvasPool.length, + visibleStackingCanvasCount: this._stackingCanvas.filter( + canvas => canvas.width > 0 && canvas.height > 0 + ).length, + }; + + this._needsFullRender = false; + this._mainCanvasDirty = false; + this._dirtyStackingCanvasIndexes.clear(); } + private _lastDebugSnapshot: Pick< + CanvasRendererDebugMetrics, + | 'canvasMemoryBytes' + | 'canvasMemorySnapshots' + | 'canvasPixelCount' + | 'canvasLayerCount' + | 'pooledStackingCanvasCount' + | 'stackingCanvasCount' + | 'totalLayerCount' + | 'visibleStackingCanvasCount' + > = { + canvasMemoryBytes: 0, + canvasMemorySnapshots: [], + canvasPixelCount: 0, + canvasLayerCount: 0, + pooledStackingCanvasCount: 0, + stackingCanvasCount: 0, + totalLayerCount: 0, + visibleStackingCanvasCount: 0, + }; + private _renderByBound( ctx: CanvasRenderingContext2D | null, matrix: DOMMatrix, rc: RoughCanvas, bound: IBound, surfaceElements?: SurfaceElementModel[], - overLay: boolean = false + overLay: boolean = false, + renderStats?: RenderPassStats ) { if (!ctx) return; + renderStats && (renderStats.renderByBoundCallCount += 1); + const elements = surfaceElements ?? (this.grid.search(bound, { @@ -305,10 +741,12 @@ export class CanvasRenderer { for (const element of elements) { const display = (element.display ?? 
true) && !element.hidden; if (display && intersects(getBoundWithRotation(element), bound)) { + renderStats && (renderStats.visibleElementCount += 1); if ( this.usePlaceholder && !(element as GfxCompatibleInterface).forceFullRender ) { + renderStats && (renderStats.placeholderElementCount += 1); ctx.save(); ctx.fillStyle = 'rgba(200, 200, 200, 0.5)'; const drawX = element.x - bound.x; @@ -316,6 +754,7 @@ export class CanvasRenderer { ctx.fillRect(drawX, drawY, element.w, element.h); ctx.restore(); } else { + renderStats && (renderStats.renderedElementCount += 1); ctx.save(); const renderFn = this.std.getOptional( ElementRendererIdentifier(element.type) @@ -333,6 +772,7 @@ export class CanvasRenderer { } if (overLay) { + renderStats && (renderStats.overlayCount += this._overlays.size); for (const overlay of this._overlays) { ctx.save(); ctx.translate(-bound.x, -bound.y); @@ -348,33 +788,38 @@ export class CanvasRenderer { const sizeUpdater = this._canvasSizeUpdater(); sizeUpdater.update(this.canvas); - - this._stackingCanvas.forEach(sizeUpdater.update); - this.refresh(); + this._invalidate({ type: 'all' }); } private _watchSurface(surfaceModel: SurfaceBlockModel) { this._disposables.add( - surfaceModel.elementAdded.subscribe(() => this.refresh()) + surfaceModel.elementAdded.subscribe(() => this.refresh({ type: 'all' })) ); this._disposables.add( - surfaceModel.elementRemoved.subscribe(() => this.refresh()) + surfaceModel.elementRemoved.subscribe(() => this.refresh({ type: 'all' })) ); this._disposables.add( - surfaceModel.localElementAdded.subscribe(() => this.refresh()) + surfaceModel.localElementAdded.subscribe(() => + this.refresh({ type: 'all' }) + ) ); this._disposables.add( - surfaceModel.localElementDeleted.subscribe(() => this.refresh()) + surfaceModel.localElementDeleted.subscribe(() => + this.refresh({ type: 'all' }) + ) ); this._disposables.add( - surfaceModel.localElementUpdated.subscribe(() => this.refresh()) + 
surfaceModel.localElementUpdated.subscribe(({ model }) => { + this.refresh({ type: 'element', element: model }); + }) ); this._disposables.add( surfaceModel.elementUpdated.subscribe(payload => { // ignore externalXYWH update cause it's updated by the renderer if (payload.props['externalXYWH']) return; - this.refresh(); + const element = surfaceModel.getElementById(payload.id); + this.refresh(element ? { type: 'element', element } : { type: 'all' }); }) ); } @@ -382,7 +827,7 @@ export class CanvasRenderer { addOverlay(overlay: Overlay) { overlay.setRenderer(this); this._overlays.add(overlay); - this.refresh(); + this.refresh({ type: 'main' }); } /** @@ -394,7 +839,7 @@ export class CanvasRenderer { container.append(this.canvas); this._resetSize(); - this.refresh(); + this.refresh({ type: 'all' }); } dispose(): void { @@ -453,8 +898,46 @@ export class CanvasRenderer { return this.provider.getPropertyValue?.(property) ?? ''; } - refresh() { - if (this._refreshRafId !== null) return; + getDebugMetrics(): CanvasRendererDebugMetrics { + return { + ...this._debugMetrics, + ...this._lastDebugSnapshot, + canvasMemoryMegabytes: + this._lastDebugSnapshot.canvasMemoryBytes / 1024 / 1024, + }; + } + + resetDebugMetrics() { + this._debugMetrics = { + refreshCount: 0, + coalescedRefreshCount: 0, + renderCount: 0, + totalRenderDurationMs: 0, + lastRenderDurationMs: 0, + maxRenderDurationMs: 0, + lastRenderMetrics: this._createRenderPassStats(), + dirtyLayerRenderCount: 0, + fallbackElementCount: 0, + }; + this._lastDebugSnapshot = { + canvasMemoryBytes: 0, + canvasMemorySnapshots: [], + canvasPixelCount: 0, + canvasLayerCount: 0, + pooledStackingCanvasCount: 0, + stackingCanvasCount: 0, + totalLayerCount: 0, + visibleStackingCanvasCount: 0, + }; + } + + refresh(target: RefreshTarget = { type: 'all' }) { + this._debugMetrics.refreshCount += 1; + this._invalidate(target); + if (this._refreshRafId !== null) { + this._debugMetrics.coalescedRefreshCount += 1; + return; + } 
this._refreshRafId = requestConnectedFrame(() => { this._refreshRafId = null; @@ -469,6 +952,6 @@ export class CanvasRenderer { overlay.setRenderer(null); this._overlays.delete(overlay); - this.refresh(); + this.refresh({ type: 'main' }); } } diff --git a/blocksuite/affine/blocks/surface/src/renderer/dom-renderer.ts b/blocksuite/affine/blocks/surface/src/renderer/dom-renderer.ts index b123a31245..b19b03b5c2 100644 --- a/blocksuite/affine/blocks/surface/src/renderer/dom-renderer.ts +++ b/blocksuite/affine/blocks/surface/src/renderer/dom-renderer.ts @@ -354,30 +354,37 @@ export class DomRenderer { this._disposables.add( surfaceModel.elementAdded.subscribe(payload => { this._markElementDirty(payload.id, UpdateType.ELEMENT_ADDED); + this._markViewportDirty(); this.refresh(); }) ); this._disposables.add( surfaceModel.elementRemoved.subscribe(payload => { this._markElementDirty(payload.id, UpdateType.ELEMENT_REMOVED); + this._markViewportDirty(); this.refresh(); }) ); this._disposables.add( surfaceModel.localElementAdded.subscribe(payload => { this._markElementDirty(payload.id, UpdateType.ELEMENT_ADDED); + this._markViewportDirty(); this.refresh(); }) ); this._disposables.add( surfaceModel.localElementDeleted.subscribe(payload => { this._markElementDirty(payload.id, UpdateType.ELEMENT_REMOVED); + this._markViewportDirty(); this.refresh(); }) ); this._disposables.add( surfaceModel.localElementUpdated.subscribe(payload => { this._markElementDirty(payload.model.id, UpdateType.ELEMENT_UPDATED); + if (payload.props['index'] || payload.props['groupId']) { + this._markViewportDirty(); + } this.refresh(); }) ); @@ -387,6 +394,9 @@ export class DomRenderer { // ignore externalXYWH update cause it's updated by the renderer if (payload.props['externalXYWH']) return; this._markElementDirty(payload.id, UpdateType.ELEMENT_UPDATED); + if (payload.props['index'] || payload.props['childIds']) { + this._markViewportDirty(); + } this.refresh(); }) ); diff --git 
a/blocksuite/affine/components/package.json b/blocksuite/affine/components/package.json index f8ff653f24..e7fdcfb021 100644 --- a/blocksuite/affine/components/package.json +++ b/blocksuite/affine/components/package.json @@ -19,7 +19,7 @@ "@blocksuite/sync": "workspace:*", "@floating-ui/dom": "^1.6.13", "@lit/context": "^1.1.2", - "@lottiefiles/dotlottie-wc": "^0.5.0", + "@lottiefiles/dotlottie-wc": "^0.9.4", "@preact/signals-core": "^1.8.0", "@toeverything/theme": "^1.1.23", "@types/hast": "^3.0.4", diff --git a/blocksuite/affine/components/src/link-preview/link.ts b/blocksuite/affine/components/src/link-preview/link.ts index 2fde9be69f..201fc209bc 100644 --- a/blocksuite/affine/components/src/link-preview/link.ts +++ b/blocksuite/affine/components/src/link-preview/link.ts @@ -1,4 +1,8 @@ -import { getHostName } from '@blocksuite/affine-shared/utils'; +import { + getHostName, + isValidUrl, + normalizeUrl, +} from '@blocksuite/affine-shared/utils'; import { PropTypes, requiredProperties } from '@blocksuite/std'; import { css, LitElement } from 'lit'; import { property } from 'lit/decorators.js'; @@ -44,15 +48,27 @@ export class LinkPreview extends LitElement { override render() { const { url } = this; + const normalizedUrl = normalizeUrl(url); + const safeUrl = + normalizedUrl && isValidUrl(normalizedUrl) ? normalizedUrl : null; + const hostName = getHostName(safeUrl ?? 
url); + + if (!safeUrl) { + return html` + + ${hostName} + + `; + } return html` - ${getHostName(url)} + ${hostName} `; } diff --git a/blocksuite/affine/data-view/package.json b/blocksuite/affine/data-view/package.json index 4663b15774..1ade75c73e 100644 --- a/blocksuite/affine/data-view/package.json +++ b/blocksuite/affine/data-view/package.json @@ -32,7 +32,7 @@ "zod": "^3.25.76" }, "devDependencies": { - "vitest": "^3.2.4" + "vitest": "^4.0.18" }, "exports": { ".": "./src/index.ts", diff --git a/blocksuite/affine/ext-loader/package.json b/blocksuite/affine/ext-loader/package.json index d84cb86dca..f63e49ca2d 100644 --- a/blocksuite/affine/ext-loader/package.json +++ b/blocksuite/affine/ext-loader/package.json @@ -15,7 +15,7 @@ "zod": "^3.25.76" }, "devDependencies": { - "vitest": "^3.2.4" + "vitest": "^4.0.18" }, "exports": { ".": "./src/index.ts" diff --git a/blocksuite/affine/ext-loader/vitest.config.ts b/blocksuite/affine/ext-loader/vitest.config.ts index 2667b267a6..404ee1c17d 100644 --- a/blocksuite/affine/ext-loader/vitest.config.ts +++ b/blocksuite/affine/ext-loader/vitest.config.ts @@ -8,7 +8,7 @@ export default defineConfig({ include: ['src/__tests__/**/*.unit.spec.ts'], testTimeout: 500, coverage: { - provider: 'istanbul', // or 'c8' + provider: 'istanbul', reporter: ['lcov'], reportsDirectory: '../../../.coverage/ext-loader', }, diff --git a/blocksuite/affine/gfx/brush/src/renderer/dom/brush.ts b/blocksuite/affine/gfx/brush/src/renderer/dom/brush.ts index 5f5dea340a..e197bd8b70 100644 --- a/blocksuite/affine/gfx/brush/src/renderer/dom/brush.ts +++ b/blocksuite/affine/gfx/brush/src/renderer/dom/brush.ts @@ -5,6 +5,8 @@ import { import type { BrushElementModel } from '@blocksuite/affine-model'; import { DefaultTheme } from '@blocksuite/affine-model'; +import { renderBrushLikeDom } from './shared'; + export const BrushDomRendererExtension = DomElementRendererExtension( 'brush', ( @@ -12,58 +14,11 @@ export const BrushDomRendererExtension = 
DomElementRendererExtension( domElement: HTMLElement, renderer: DomRenderer ) => { - const { zoom } = renderer.viewport; - const [, , w, h] = model.deserializedXYWH; - - // Early return if invalid dimensions - if (w <= 0 || h <= 0) { - return; - } - - // Early return if no commands - if (!model.commands) { - return; - } - - // Clear previous content - domElement.innerHTML = ''; - - // Get color value - const color = renderer.getColorValue(model.color, DefaultTheme.black, true); - - // Create SVG element - const svg = document.createElementNS('http://www.w3.org/2000/svg', 'svg'); - svg.style.position = 'absolute'; - svg.style.left = '0'; - svg.style.top = '0'; - svg.style.width = `${w * zoom}px`; - svg.style.height = `${h * zoom}px`; - svg.style.overflow = 'visible'; - svg.style.pointerEvents = 'none'; - svg.setAttribute('viewBox', `0 0 ${w} ${h}`); - - // Apply rotation transform - if (model.rotate !== 0) { - svg.style.transform = `rotate(${model.rotate}deg)`; - svg.style.transformOrigin = 'center'; - } - - // Create path element for the brush stroke - const pathElement = document.createElementNS( - 'http://www.w3.org/2000/svg', - 'path' - ); - pathElement.setAttribute('d', model.commands); - pathElement.setAttribute('fill', color); - pathElement.setAttribute('stroke', 'none'); - - svg.append(pathElement); - domElement.replaceChildren(svg); - - // Set element size and position - domElement.style.width = `${w * zoom}px`; - domElement.style.height = `${h * zoom}px`; - domElement.style.overflow = 'visible'; - domElement.style.pointerEvents = 'none'; + renderBrushLikeDom({ + model, + domElement, + renderer, + color: renderer.getColorValue(model.color, DefaultTheme.black, true), + }); } ); diff --git a/blocksuite/affine/gfx/brush/src/renderer/dom/highlighter.ts b/blocksuite/affine/gfx/brush/src/renderer/dom/highlighter.ts index e15f2a300e..5c410e9f72 100644 --- a/blocksuite/affine/gfx/brush/src/renderer/dom/highlighter.ts +++ 
b/blocksuite/affine/gfx/brush/src/renderer/dom/highlighter.ts @@ -5,6 +5,8 @@ import { import type { HighlighterElementModel } from '@blocksuite/affine-model'; import { DefaultTheme } from '@blocksuite/affine-model'; +import { renderBrushLikeDom } from './shared'; + export const HighlighterDomRendererExtension = DomElementRendererExtension( 'highlighter', ( @@ -12,62 +14,15 @@ export const HighlighterDomRendererExtension = DomElementRendererExtension( domElement: HTMLElement, renderer: DomRenderer ) => { - const { zoom } = renderer.viewport; - const [, , w, h] = model.deserializedXYWH; - - // Early return if invalid dimensions - if (w <= 0 || h <= 0) { - return; - } - - // Early return if no commands - if (!model.commands) { - return; - } - - // Clear previous content - domElement.innerHTML = ''; - - // Get color value - const color = renderer.getColorValue( - model.color, - DefaultTheme.hightlighterColor, - true - ); - - // Create SVG element - const svg = document.createElementNS('http://www.w3.org/2000/svg', 'svg'); - svg.style.position = 'absolute'; - svg.style.left = '0'; - svg.style.top = '0'; - svg.style.width = `${w * zoom}px`; - svg.style.height = `${h * zoom}px`; - svg.style.overflow = 'visible'; - svg.style.pointerEvents = 'none'; - svg.setAttribute('viewBox', `0 0 ${w} ${h}`); - - // Apply rotation transform - if (model.rotate !== 0) { - svg.style.transform = `rotate(${model.rotate}deg)`; - svg.style.transformOrigin = 'center'; - } - - // Create path element for the highlighter stroke - const pathElement = document.createElementNS( - 'http://www.w3.org/2000/svg', - 'path' - ); - pathElement.setAttribute('d', model.commands); - pathElement.setAttribute('fill', color); - pathElement.setAttribute('stroke', 'none'); - - svg.append(pathElement); - domElement.replaceChildren(svg); - - // Set element size and position - domElement.style.width = `${w * zoom}px`; - domElement.style.height = `${h * zoom}px`; - domElement.style.overflow = 'visible'; - 
domElement.style.pointerEvents = 'none'; + renderBrushLikeDom({ + model, + domElement, + renderer, + color: renderer.getColorValue( + model.color, + DefaultTheme.hightlighterColor, + true + ), + }); } ); diff --git a/blocksuite/affine/gfx/brush/src/renderer/dom/shared.ts b/blocksuite/affine/gfx/brush/src/renderer/dom/shared.ts new file mode 100644 index 0000000000..04198b003a --- /dev/null +++ b/blocksuite/affine/gfx/brush/src/renderer/dom/shared.ts @@ -0,0 +1,82 @@ +import type { DomRenderer } from '@blocksuite/affine-block-surface'; +import type { + BrushElementModel, + HighlighterElementModel, +} from '@blocksuite/affine-model'; + +const SVG_NS = 'http://www.w3.org/2000/svg'; + +type BrushLikeModel = BrushElementModel | HighlighterElementModel; + +type RetainedBrushDom = { + path: SVGPathElement; + svg: SVGSVGElement; +}; + +const retainedBrushDom = new WeakMap(); + +function clearBrushLikeDom(domElement: HTMLElement) { + retainedBrushDom.delete(domElement); + domElement.replaceChildren(); +} + +function getRetainedBrushDom(domElement: HTMLElement) { + const existing = retainedBrushDom.get(domElement); + + if (existing) { + return existing; + } + + const svg = document.createElementNS(SVG_NS, 'svg'); + svg.style.position = 'absolute'; + svg.style.left = '0'; + svg.style.top = '0'; + svg.style.overflow = 'visible'; + svg.style.pointerEvents = 'none'; + + const path = document.createElementNS(SVG_NS, 'path'); + path.setAttribute('stroke', 'none'); + svg.append(path); + + const retained = { svg, path }; + retainedBrushDom.set(domElement, retained); + domElement.replaceChildren(svg); + + return retained; +} + +export function renderBrushLikeDom({ + color, + domElement, + model, + renderer, +}: { + color: string; + domElement: HTMLElement; + model: BrushLikeModel; + renderer: DomRenderer; +}) { + const { zoom } = renderer.viewport; + const [, , w, h] = model.deserializedXYWH; + + if (w <= 0 || h <= 0 || !model.commands) { + clearBrushLikeDom(domElement); + return; + 
} + + const { path, svg } = getRetainedBrushDom(domElement); + + svg.style.width = `${w * zoom}px`; + svg.style.height = `${h * zoom}px`; + svg.style.transform = model.rotate === 0 ? '' : `rotate(${model.rotate}deg)`; + svg.style.transformOrigin = model.rotate === 0 ? '' : 'center'; + svg.setAttribute('viewBox', `0 0 ${w} ${h}`); + + path.setAttribute('d', model.commands); + path.setAttribute('fill', color); + + domElement.style.width = `${w * zoom}px`; + domElement.style.height = `${h * zoom}px`; + domElement.style.overflow = 'visible'; + domElement.style.pointerEvents = 'none'; +} diff --git a/blocksuite/affine/gfx/connector/src/renderer/dom-renderer.ts b/blocksuite/affine/gfx/connector/src/renderer/dom-renderer.ts index 46a6a07112..3f9d992344 100644 --- a/blocksuite/affine/gfx/connector/src/renderer/dom-renderer.ts +++ b/blocksuite/affine/gfx/connector/src/renderer/dom-renderer.ts @@ -14,6 +14,8 @@ import { PointLocation, SVGPathBuilder } from '@blocksuite/global/gfx'; import { isConnectorWithLabel } from '../connector-manager'; import { DEFAULT_ARROW_SIZE } from './utils'; +const SVG_NS = 'http://www.w3.org/2000/svg'; + interface PathBounds { minX: number; minY: number; @@ -21,6 +23,15 @@ interface PathBounds { maxY: number; } +type RetainedConnectorDom = { + defs: SVGDefsElement; + label: HTMLDivElement | null; + path: SVGPathElement; + svg: SVGSVGElement; +}; + +const retainedConnectorDom = new WeakMap(); + function calculatePathBounds(path: PointLocation[]): PathBounds { if (path.length === 0) { return { minX: 0, minY: 0, maxX: 0, maxY: 0 }; @@ -81,10 +92,7 @@ function createArrowMarker( strokeWidth: number, isStart: boolean = false ): SVGMarkerElement { - const marker = document.createElementNS( - 'http://www.w3.org/2000/svg', - 'marker' - ); + const marker = document.createElementNS(SVG_NS, 'marker'); const size = DEFAULT_ARROW_SIZE * (strokeWidth / 2); marker.id = id; @@ -98,10 +106,7 @@ function createArrowMarker( switch (style) { case 'Arrow': { - const 
path = document.createElementNS( - 'http://www.w3.org/2000/svg', - 'path' - ); + const path = document.createElementNS(SVG_NS, 'path'); path.setAttribute( 'd', isStart ? 'M 20 5 L 10 10 L 20 15 Z' : 'M 0 5 L 10 10 L 0 15 Z' @@ -112,10 +117,7 @@ function createArrowMarker( break; } case 'Triangle': { - const path = document.createElementNS( - 'http://www.w3.org/2000/svg', - 'path' - ); + const path = document.createElementNS(SVG_NS, 'path'); path.setAttribute( 'd', isStart ? 'M 20 7 L 12 10 L 20 13 Z' : 'M 0 7 L 8 10 L 0 13 Z' @@ -126,10 +128,7 @@ function createArrowMarker( break; } case 'Circle': { - const circle = document.createElementNS( - 'http://www.w3.org/2000/svg', - 'circle' - ); + const circle = document.createElementNS(SVG_NS, 'circle'); circle.setAttribute('cx', '10'); circle.setAttribute('cy', '10'); circle.setAttribute('r', '4'); @@ -139,10 +138,7 @@ function createArrowMarker( break; } case 'Diamond': { - const path = document.createElementNS( - 'http://www.w3.org/2000/svg', - 'path' - ); + const path = document.createElementNS(SVG_NS, 'path'); path.setAttribute('d', 'M 10 6 L 14 10 L 10 14 L 6 10 Z'); path.setAttribute('fill', color); path.setAttribute('stroke', color); @@ -154,13 +150,64 @@ function createArrowMarker( return marker; } +function clearRetainedConnectorDom(element: HTMLElement) { + retainedConnectorDom.delete(element); + element.replaceChildren(); +} + +function getRetainedConnectorDom(element: HTMLElement): RetainedConnectorDom { + const existing = retainedConnectorDom.get(element); + + if (existing) { + return existing; + } + + const svg = document.createElementNS(SVG_NS, 'svg'); + svg.style.position = 'absolute'; + svg.style.overflow = 'visible'; + svg.style.pointerEvents = 'none'; + + const defs = document.createElementNS(SVG_NS, 'defs'); + const path = document.createElementNS(SVG_NS, 'path'); + path.setAttribute('fill', 'none'); + path.setAttribute('stroke-linecap', 'round'); + path.setAttribute('stroke-linejoin', 'round'); + + 
svg.append(defs, path); + element.replaceChildren(svg); + + const retained = { + svg, + defs, + path, + label: null, + }; + retainedConnectorDom.set(element, retained); + + return retained; +} + +function getOrCreateLabelElement(retained: RetainedConnectorDom) { + if (retained.label) { + return retained.label; + } + + const label = document.createElement('div'); + retained.svg.insertAdjacentElement('afterend', label); + retained.label = label; + + return label; +} + function renderConnectorLabel( model: ConnectorElementModel, - container: HTMLElement, + retained: RetainedConnectorDom, renderer: DomRenderer, zoom: number ) { if (!isConnectorWithLabel(model) || !model.labelXYWH) { + retained.label?.remove(); + retained.label = null; return; } @@ -176,8 +223,7 @@ function renderConnectorLabel( }, } = model; - // Create label element - const labelElement = document.createElement('div'); + const labelElement = getOrCreateLabelElement(retained); labelElement.style.position = 'absolute'; labelElement.style.left = `${lx * zoom}px`; labelElement.style.top = `${ly * zoom}px`; @@ -210,11 +256,7 @@ function renderConnectorLabel( labelElement.style.wordWrap = 'break-word'; // Add text content - if (model.text) { - labelElement.textContent = model.text.toString(); - } - - container.append(labelElement); + labelElement.textContent = model.text ? 
model.text.toString() : ''; } /** @@ -241,14 +283,13 @@ export const connectorBaseDomRenderer = ( stroke, } = model; - // Clear previous content - element.innerHTML = ''; - - // Early return if no path points if (!points || points.length < 2) { + clearRetainedConnectorDom(element); return; } + const retained = getRetainedConnectorDom(element); + // Calculate bounds for the SVG viewBox const pathBounds = calculatePathBounds(points); const padding = Math.max(strokeWidth * 2, 20); // Add padding for arrows @@ -257,8 +298,7 @@ export const connectorBaseDomRenderer = ( const offsetX = pathBounds.minX - padding; const offsetY = pathBounds.minY - padding; - // Create SVG element - const svg = document.createElementNS('http://www.w3.org/2000/svg', 'svg'); + const { defs, path, svg } = retained; svg.style.position = 'absolute'; svg.style.left = `${offsetX * zoom}px`; svg.style.top = `${offsetY * zoom}px`; @@ -268,49 +308,43 @@ export const connectorBaseDomRenderer = ( svg.style.pointerEvents = 'none'; svg.setAttribute('viewBox', `0 0 ${svgWidth / zoom} ${svgHeight / zoom}`); - // Create defs for markers - const defs = document.createElementNS('http://www.w3.org/2000/svg', 'defs'); - svg.append(defs); - const strokeColor = renderer.getColorValue( stroke, DefaultTheme.connectorColor, true ); - // Create markers for endpoints + const markers: SVGMarkerElement[] = []; let startMarkerId = ''; let endMarkerId = ''; if (frontEndpointStyle !== 'None') { startMarkerId = `start-marker-${model.id}`; - const startMarker = createArrowMarker( - startMarkerId, - frontEndpointStyle, - strokeColor, - strokeWidth, - true + markers.push( + createArrowMarker( + startMarkerId, + frontEndpointStyle, + strokeColor, + strokeWidth, + true + ) ); - defs.append(startMarker); } if (rearEndpointStyle !== 'None') { endMarkerId = `end-marker-${model.id}`; - const endMarker = createArrowMarker( - endMarkerId, - rearEndpointStyle, - strokeColor, - strokeWidth, - false + markers.push( + createArrowMarker( + 
endMarkerId, + rearEndpointStyle, + strokeColor, + strokeWidth, + false + ) ); - defs.append(endMarker); } - // Create path element - const pathElement = document.createElementNS( - 'http://www.w3.org/2000/svg', - 'path' - ); + defs.replaceChildren(...markers); // Adjust points relative to the SVG coordinate system const adjustedPoints = points.map(point => { @@ -334,29 +368,25 @@ export const connectorBaseDomRenderer = ( }); const pathData = createConnectorPath(adjustedPoints, mode); - pathElement.setAttribute('d', pathData); - pathElement.setAttribute('stroke', strokeColor); - pathElement.setAttribute('stroke-width', String(strokeWidth)); - pathElement.setAttribute('fill', 'none'); - pathElement.setAttribute('stroke-linecap', 'round'); - pathElement.setAttribute('stroke-linejoin', 'round'); - - // Apply stroke style + path.setAttribute('d', pathData); + path.setAttribute('stroke', strokeColor); + path.setAttribute('stroke-width', String(strokeWidth)); if (strokeStyle === 'dash') { - pathElement.setAttribute('stroke-dasharray', '12,12'); + path.setAttribute('stroke-dasharray', '12,12'); + } else { + path.removeAttribute('stroke-dasharray'); } - - // Apply markers if (startMarkerId) { - pathElement.setAttribute('marker-start', `url(#${startMarkerId})`); + path.setAttribute('marker-start', `url(#${startMarkerId})`); + } else { + path.removeAttribute('marker-start'); } if (endMarkerId) { - pathElement.setAttribute('marker-end', `url(#${endMarkerId})`); + path.setAttribute('marker-end', `url(#${endMarkerId})`); + } else { + path.removeAttribute('marker-end'); } - svg.append(pathElement); - element.append(svg); - // Set element size and position element.style.width = `${model.w * zoom}px`; element.style.height = `${model.h * zoom}px`; @@ -370,7 +400,11 @@ export const connectorDomRenderer = ( renderer: DomRenderer ): void => { connectorBaseDomRenderer(model, element, renderer); - renderConnectorLabel(model, element, renderer, renderer.viewport.zoom); + + const retained 
= retainedConnectorDom.get(element); + if (!retained) return; + + renderConnectorLabel(model, retained, renderer, renderer.viewport.zoom); }; /** diff --git a/blocksuite/affine/gfx/group/package.json b/blocksuite/affine/gfx/group/package.json index d4ea825f95..3388ccf65a 100644 --- a/blocksuite/affine/gfx/group/package.json +++ b/blocksuite/affine/gfx/group/package.json @@ -34,7 +34,7 @@ "zod": "^3.25.76" }, "devDependencies": { - "vitest": "^3.2.4" + "vitest": "^4.0.18" }, "exports": { ".": "./src/index.ts", diff --git a/blocksuite/affine/gfx/pointer/package.json b/blocksuite/affine/gfx/pointer/package.json index 737f420538..fdf0a0e40b 100644 --- a/blocksuite/affine/gfx/pointer/package.json +++ b/blocksuite/affine/gfx/pointer/package.json @@ -32,7 +32,7 @@ "zod": "^3.25.76" }, "devDependencies": { - "vitest": "^3.2.4" + "vitest": "^4.0.18" }, "exports": { ".": "./src/index.ts", diff --git a/blocksuite/affine/gfx/shape/src/element-renderer/shape-dom/index.ts b/blocksuite/affine/gfx/shape/src/element-renderer/shape-dom/index.ts index 749ad67660..ff4eaa713f 100644 --- a/blocksuite/affine/gfx/shape/src/element-renderer/shape-dom/index.ts +++ b/blocksuite/affine/gfx/shape/src/element-renderer/shape-dom/index.ts @@ -6,6 +6,37 @@ import { SVGShapeBuilder } from '@blocksuite/global/gfx'; import { manageClassNames, setStyles } from './utils'; +const SVG_NS = 'http://www.w3.org/2000/svg'; + +type RetainedShapeDom = { + polygon: SVGPolygonElement | null; + svg: SVGSVGElement | null; + text: HTMLDivElement | null; +}; + +type RetainedShapeSvg = { + polygon: SVGPolygonElement; + svg: SVGSVGElement; +}; + +const retainedShapeDom = new WeakMap(); + +function getRetainedShapeDom(element: HTMLElement): RetainedShapeDom { + const existing = retainedShapeDom.get(element); + + if (existing) { + return existing; + } + + const retained = { + svg: null, + polygon: null, + text: null, + }; + retainedShapeDom.set(element, retained); + return retained; +} + function 
applyShapeSpecificStyles( model: ShapeElementModel, element: HTMLElement, @@ -14,10 +45,6 @@ function applyShapeSpecificStyles( // Reset properties that might be set by different shape types element.style.removeProperty('clip-path'); element.style.removeProperty('border-radius'); - // Clear DOM for shapes that don't use SVG, or if type changes from SVG-based to non-SVG-based - if (model.shapeType !== 'diamond' && model.shapeType !== 'triangle') { - while (element.firstChild) element.firstChild.remove(); - } switch (model.shapeType) { case 'rect': { @@ -42,6 +69,54 @@ function applyShapeSpecificStyles( // No 'else' needed to clear styles, as they are reset at the beginning of the function. } +function getOrCreateSvg( + retained: RetainedShapeDom, + element: HTMLElement +): RetainedShapeSvg { + if (retained.svg && retained.polygon) { + return { + svg: retained.svg, + polygon: retained.polygon, + }; + } + + const svg = document.createElementNS(SVG_NS, 'svg'); + svg.setAttribute('width', '100%'); + svg.setAttribute('height', '100%'); + svg.setAttribute('preserveAspectRatio', 'none'); + + const polygon = document.createElementNS(SVG_NS, 'polygon'); + svg.append(polygon); + + retained.svg = svg; + retained.polygon = polygon; + element.prepend(svg); + + return { svg, polygon }; +} + +function removeSvg(retained: RetainedShapeDom) { + retained.svg?.remove(); + retained.svg = null; + retained.polygon = null; +} + +function getOrCreateText(retained: RetainedShapeDom, element: HTMLElement) { + if (retained.text) { + return retained.text; + } + + const text = document.createElement('div'); + retained.text = text; + element.append(text); + return text; +} + +function removeText(retained: RetainedShapeDom) { + retained.text?.remove(); + retained.text = null; +} + function applyBorderStyles( model: ShapeElementModel, element: HTMLElement, @@ -99,8 +174,7 @@ export const shapeDomRenderer = ( const { zoom } = renderer.viewport; const unscaledWidth = model.w; const unscaledHeight = 
model.h; - - const newChildren: Element[] = []; + const retained = getRetainedShapeDom(element); const fillColor = renderer.getColorValue( model.fillColor, @@ -124,6 +198,7 @@ export const shapeDomRenderer = ( // For diamond and triangle, fill and border are handled by inline SVG element.style.border = 'none'; // Ensure no standard CSS border interferes element.style.backgroundColor = 'transparent'; // Host element is transparent + const { polygon, svg } = getOrCreateSvg(retained, element); const strokeW = model.strokeWidth; @@ -155,37 +230,30 @@ export const shapeDomRenderer = ( // Determine fill color const finalFillColor = model.filled ? fillColor : 'transparent'; - // Build SVG safely with DOM-API - const SVG_NS = 'http://www.w3.org/2000/svg'; - const svg = document.createElementNS(SVG_NS, 'svg'); - svg.setAttribute('width', '100%'); - svg.setAttribute('height', '100%'); svg.setAttribute('viewBox', `0 0 ${unscaledWidth} ${unscaledHeight}`); - svg.setAttribute('preserveAspectRatio', 'none'); - - const polygon = document.createElementNS(SVG_NS, 'polygon'); polygon.setAttribute('points', svgPoints); polygon.setAttribute('fill', finalFillColor); polygon.setAttribute('stroke', finalStrokeColor); polygon.setAttribute('stroke-width', String(strokeW)); if (finalStrokeDasharray !== 'none') { polygon.setAttribute('stroke-dasharray', finalStrokeDasharray); + } else { + polygon.removeAttribute('stroke-dasharray'); } - svg.append(polygon); - - newChildren.push(svg); } else { // Standard rendering for other shapes (e.g., rect, ellipse) - // innerHTML was already cleared by applyShapeSpecificStyles if necessary + removeSvg(retained); element.style.backgroundColor = model.filled ? 
fillColor : 'transparent'; applyBorderStyles(model, element, strokeColor, zoom); // Uses standard CSS border } if (model.textDisplay && model.text) { const str = model.text.toString(); - const textElement = document.createElement('div'); + const textElement = getOrCreateText(retained, element); if (isRTL(str)) { textElement.dir = 'rtl'; + } else { + textElement.removeAttribute('dir'); } textElement.style.position = 'absolute'; textElement.style.inset = '0'; @@ -210,12 +278,10 @@ export const shapeDomRenderer = ( true ); textElement.textContent = str; - newChildren.push(textElement); + } else { + removeText(retained); } - // Replace existing children to avoid memory leaks - element.replaceChildren(...newChildren); - applyTransformStyles(model, element); manageClassNames(model, element); diff --git a/blocksuite/affine/inlines/comment/package.json b/blocksuite/affine/inlines/comment/package.json index 3e58263309..a33e0b2586 100644 --- a/blocksuite/affine/inlines/comment/package.json +++ b/blocksuite/affine/inlines/comment/package.json @@ -29,7 +29,7 @@ "zod": "^3.25.76" }, "devDependencies": { - "vitest": "^3.2.4" + "vitest": "^4.0.18" }, "exports": { ".": "./src/index.ts", diff --git a/blocksuite/affine/inlines/footnote/package.json b/blocksuite/affine/inlines/footnote/package.json index 3c826cb11e..849f2e5157 100644 --- a/blocksuite/affine/inlines/footnote/package.json +++ b/blocksuite/affine/inlines/footnote/package.json @@ -34,7 +34,8 @@ "zod": "^3.25.76" }, "devDependencies": { - "vitest": "^3.2.4" + "@vitest/browser-playwright": "^4.0.18", + "vitest": "^4.0.18" }, "exports": { ".": "./src/index.ts", diff --git a/blocksuite/affine/inlines/footnote/src/footnote-node/footnote-node.ts b/blocksuite/affine/inlines/footnote/src/footnote-node/footnote-node.ts index 8a02e583c1..f0bbcf2056 100644 --- a/blocksuite/affine/inlines/footnote/src/footnote-node/footnote-node.ts +++ b/blocksuite/affine/inlines/footnote/src/footnote-node/footnote-node.ts @@ -4,6 +4,7 @@ import 
type { FootNote } from '@blocksuite/affine-model'; import { CitationProvider } from '@blocksuite/affine-shared/services'; import { unsafeCSSVarV2 } from '@blocksuite/affine-shared/theme'; import type { AffineTextAttributes } from '@blocksuite/affine-shared/types'; +import { isValidUrl, normalizeUrl } from '@blocksuite/affine-shared/utils'; import { WithDisposable } from '@blocksuite/global/lit'; import { BlockSelection, @@ -152,7 +153,9 @@ export class AffineFootnoteNode extends WithDisposable(ShadowlessElement) { }; private readonly _handleUrlReference = (url: string) => { - window.open(url, '_blank'); + const normalizedUrl = normalizeUrl(url); + if (!normalizedUrl || !isValidUrl(normalizedUrl)) return; + window.open(normalizedUrl, '_blank', 'noopener,noreferrer'); }; private readonly _updateFootnoteAttributes = (footnote: FootNote) => { diff --git a/blocksuite/affine/inlines/footnote/vitest.config.ts b/blocksuite/affine/inlines/footnote/vitest.config.ts index 05591362f9..5cb5ee22ac 100644 --- a/blocksuite/affine/inlines/footnote/vitest.config.ts +++ b/blocksuite/affine/inlines/footnote/vitest.config.ts @@ -1,3 +1,4 @@ +import { playwright } from '@vitest/browser-playwright'; import { defineConfig } from 'vitest/config'; export default defineConfig({ @@ -8,10 +9,9 @@ export default defineConfig({ browser: { enabled: true, headless: true, - name: 'chromium', - provider: 'playwright', + instances: [{ browser: 'chromium' }], + provider: playwright(), isolate: false, - providerOptions: {}, }, include: ['src/__tests__/**/*.unit.spec.ts'], testTimeout: 500, diff --git a/blocksuite/affine/model/src/elements/connector/connector.ts b/blocksuite/affine/model/src/elements/connector/connector.ts index a242f5a350..86917dcac3 100644 --- a/blocksuite/affine/model/src/elements/connector/connector.ts +++ b/blocksuite/affine/model/src/elements/connector/connector.ts @@ -177,6 +177,11 @@ export class ConnectorElementModel extends GfxPrimitiveElementModel(p => [p[0], p[1]]); const 
point = Polyline.pointAt(points, offsetDistance); @@ -300,6 +313,10 @@ export class ConnectorElementModel extends GfxPrimitiveElementModel { + let caretRangeFromPointSpy: MockInstance< + (clientX: number, clientY: number) => Range | null + >; + let resetNativeSelectionSpy: MockInstance<(range: Range | null) => void>; + + beforeEach(() => { + caretRangeFromPointSpy = vi.spyOn( + PointToRangeUtils.api, + 'caretRangeFromPoint' + ); + resetNativeSelectionSpy = vi.spyOn( + PointToRangeUtils.api, + 'resetNativeSelection' + ); + }); + + it('does nothing if caretRangeFromPoint returns null', () => { + caretRangeFromPointSpy.mockReturnValue(null); + + handleNativeRangeAtPoint(10, 10); + expect(resetNativeSelectionSpy).not.toHaveBeenCalled(); + }); + + it('keeps range untouched if startContainer is a Text node', () => { + const div = document.createElement('div'); + div.textContent = 'hello'; + + const text = div.firstChild!; + + const range = document.createRange(); + range.setStart(text, 2); + range.collapse(true); + + caretRangeFromPointSpy.mockReturnValue(range); + + handleNativeRangeAtPoint(10, 10); + + expect(range.startContainer).toBe(text); + expect(range.startOffset).toBe(2); + expect(resetNativeSelectionSpy).toHaveBeenCalled(); + }); + + it('moves caret into direct text child when clicking element', () => { + const div = document.createElement('div'); + div.append('hello'); + + const range = document.createRange(); + range.setStart(div, 1); + range.collapse(true); + + caretRangeFromPointSpy.mockReturnValue(range); + + handleNativeRangeAtPoint(10, 10); + + expect(range.startContainer.nodeType).toBe(Node.TEXT_NODE); + expect(range.startContainer.textContent).toBe('hello'); + expect(range.startOffset).toBe(5); + expect(resetNativeSelectionSpy).toHaveBeenCalled(); + }); + + it('moves caret to last meaningful text inside nested element', () => { + const div = document.createElement('div'); + div.innerHTML = `abc`; + + const range = document.createRange(); + 
range.setStart(div, 2); + range.collapse(true); + + caretRangeFromPointSpy.mockReturnValue(range); + + handleNativeRangeAtPoint(10, 10); + + expect(range.startContainer.nodeType).toBe(Node.TEXT_NODE); + expect(range.startContainer.textContent).toBe('c'); + expect(range.startOffset).toBe(1); + expect(resetNativeSelectionSpy).toHaveBeenCalled(); + }); + + it('falls back to searching startContainer when offset element has no text', () => { + const div = document.createElement('div'); + div.innerHTML = `ok`; + + const range = document.createRange(); + range.setStart(div, 1); + range.collapse(true); + + caretRangeFromPointSpy.mockReturnValue(range); + + handleNativeRangeAtPoint(10, 10); + + expect(range.startContainer.textContent).toBe('ok'); + expect(range.startOffset).toBe(2); + expect(resetNativeSelectionSpy).toHaveBeenCalled(); + }); +}); diff --git a/blocksuite/affine/shared/src/utils/dom/point-to-range.ts b/blocksuite/affine/shared/src/utils/dom/point-to-range.ts index d7e879c844..1ec514f90d 100644 --- a/blocksuite/affine/shared/src/utils/dom/point-to-range.ts +++ b/blocksuite/affine/shared/src/utils/dom/point-to-range.ts @@ -88,11 +88,73 @@ export function getCurrentNativeRange(selection = window.getSelection()) { return selection.getRangeAt(0); } +// functions need to be mocked in unit-test +export const api = { + caretRangeFromPoint, + resetNativeSelection, +}; + export function handleNativeRangeAtPoint(x: number, y: number) { - const range = caretRangeFromPoint(x, y); + const range = api.caretRangeFromPoint(x, y); + if (range) { + normalizeCaretRange(range); + } + const startContainer = range?.startContainer; // click on rich text if (startContainer instanceof Node) { - resetNativeSelection(range); + api.resetNativeSelection(range); + } +} + +function lastMeaningfulTextNode(node: Node) { + const walker = document.createTreeWalker(node, NodeFilter.SHOW_TEXT, { + acceptNode(node) { + return node.textContent && node.textContent?.trim().length > 0 + ? 
NodeFilter.FILTER_ACCEPT + : NodeFilter.FILTER_REJECT; + }, + }); + + let last = null; + while (walker.nextNode()) { + last = walker.currentNode; + } + return last; +} + +function normalizeCaretRange(range: Range) { + let { startContainer, startOffset } = range; + if (startContainer.nodeType === Node.TEXT_NODE) return; + + // Try to find text in the element at `startOffset` + const offsetEl = + startOffset > 0 + ? startContainer.childNodes[startOffset - 1] + : startContainer.childNodes[0]; + if (offsetEl) { + if (offsetEl.nodeType === Node.TEXT_NODE) { + range.setStart( + offsetEl, + startOffset > 0 ? (offsetEl.textContent?.length ?? 0) : 0 + ); + range.collapse(true); + return; + } + + const text = lastMeaningfulTextNode(offsetEl); + if (text) { + range.setStart(text, text.textContent?.length ?? 0); + range.collapse(true); + return; + } + } + + // Fallback, try to find text in startContainer + const text = lastMeaningfulTextNode(startContainer); + if (text) { + range.setStart(text, text.textContent?.length ?? 
0; + range.collapse(true); + return; } } diff --git a/blocksuite/affine/shared/src/utils/print-to-pdf.ts b/blocksuite/affine/shared/src/utils/print-to-pdf.ts index b3acb536cd..61a656e5de 100644 --- a/blocksuite/affine/shared/src/utils/print-to-pdf.ts +++ b/blocksuite/affine/shared/src/utils/print-to-pdf.ts @@ -49,6 +49,9 @@ export async function printToPdf( --affine-background-primary: #fff !important; --affine-background-secondary: #fff !important; --affine-background-tertiary: #fff !important; + --affine-background-code-block: #f5f5f5 !important; + --affine-quote-color: #e3e3e3 !important; + --affine-border-color: #e3e3e3 !important; } body, [data-theme='dark'] { color: #000 !important; diff --git a/blocksuite/affine/shared/src/utils/url.ts b/blocksuite/affine/shared/src/utils/url.ts index f1a67a2ab8..4c0f97d273 100644 --- a/blocksuite/affine/shared/src/utils/url.ts +++ b/blocksuite/affine/shared/src/utils/url.ts @@ -24,6 +24,11 @@ const toURL = (str: string) => { } }; +const hasAllowedScheme = (url: URL) => { + const protocol = url.protocol.slice(0, -1).toLowerCase(); + return ALLOWED_SCHEMES.has(protocol); +}; + function resolveURL(str: string, baseUrl: string, padded = false) { const url = toURL(str); if (!url) return null; @@ -61,6 +66,7 @@ export function normalizeUrl(str: string) { // Formatted if (url) { + if (!hasAllowedScheme(url)) return ''; if (!str.endsWith('/') && url.href.endsWith('/')) { return url.href.substring(0, url.href.length - 1); } diff --git a/blocksuite/affine/shared/vitest.config.ts b/blocksuite/affine/shared/vitest.config.ts index 9c1c45d368..2b55b42dfd 100644 --- a/blocksuite/affine/shared/vitest.config.ts +++ b/blocksuite/affine/shared/vitest.config.ts @@ -9,7 +9,7 @@ export default defineConfig({ include: ['src/__tests__/**/*.unit.spec.ts'], testTimeout: 1000, coverage: { - provider: 'istanbul', // or 'c8' + provider: 'istanbul', reporter: ['lcov'], reportsDirectory: '../../../.coverage/affine-shared', }, diff --git 
a/blocksuite/framework/global/package.json b/blocksuite/framework/global/package.json index a8d9dae4db..d6ea27a6bc 100644 --- a/blocksuite/framework/global/package.json +++ b/blocksuite/framework/global/package.json @@ -62,7 +62,7 @@ "zod": "^3.25.76" }, "devDependencies": { - "vitest": "^3.2.4" + "vitest": "^4.0.18" }, "version": "0.26.3" } diff --git a/blocksuite/framework/global/src/__tests__/curve.unit.spec.ts b/blocksuite/framework/global/src/__tests__/curve.unit.spec.ts new file mode 100644 index 0000000000..abfd493b2e --- /dev/null +++ b/blocksuite/framework/global/src/__tests__/curve.unit.spec.ts @@ -0,0 +1,22 @@ +import { describe, expect, test } from 'vitest'; + +import { getBezierParameters } from '../gfx/curve.js'; +import { PointLocation } from '../gfx/model/index.js'; + +describe('getBezierParameters', () => { + test('should handle empty path', () => { + expect(() => getBezierParameters([])).not.toThrow(); + expect(getBezierParameters([])).toEqual([ + new PointLocation(), + new PointLocation(), + new PointLocation(), + new PointLocation(), + ]); + }); + + test('should handle single-point path', () => { + const point = new PointLocation([10, 20]); + + expect(getBezierParameters([point])).toEqual([point, point, point, point]); + }); +}); diff --git a/blocksuite/framework/global/src/gfx/curve.ts b/blocksuite/framework/global/src/gfx/curve.ts index ac08d00fbd..4eebbb86c9 100644 --- a/blocksuite/framework/global/src/gfx/curve.ts +++ b/blocksuite/framework/global/src/gfx/curve.ts @@ -142,6 +142,11 @@ export function getBezierNearestPoint( export function getBezierParameters( points: PointLocation[] ): BezierCurveParameters { + if (points.length === 0) { + const point = new PointLocation(); + return [point, point, point, point]; + } + // Fallback for degenerate Bezier curve (all points are at the same position) if (points.length === 1) { const point = points[0]; diff --git a/blocksuite/framework/global/vitest.config.ts 
b/blocksuite/framework/global/vitest.config.ts index 09ab57ed83..80a84e0aec 100644 --- a/blocksuite/framework/global/vitest.config.ts +++ b/blocksuite/framework/global/vitest.config.ts @@ -5,7 +5,7 @@ export default defineConfig({ include: ['src/__tests__/**/*.unit.spec.ts'], testTimeout: 500, coverage: { - provider: 'istanbul', // or 'c8' + provider: 'istanbul', reporter: ['lcov'], reportsDirectory: '../../../.coverage/global', }, diff --git a/blocksuite/framework/std/package.json b/blocksuite/framework/std/package.json index 2790f50e26..c324f348ce 100644 --- a/blocksuite/framework/std/package.json +++ b/blocksuite/framework/std/package.json @@ -33,7 +33,8 @@ "zod": "^3.25.76" }, "devDependencies": { - "vitest": "^3.2.4" + "@vitest/browser-playwright": "^4.0.18", + "vitest": "^4.0.18" }, "exports": { ".": "./src/index.ts", diff --git a/blocksuite/framework/std/src/gfx/layer.ts b/blocksuite/framework/std/src/gfx/layer.ts index 4ee2f11e99..12a9522928 100644 --- a/blocksuite/framework/std/src/gfx/layer.ts +++ b/blocksuite/framework/std/src/gfx/layer.ts @@ -596,7 +596,7 @@ export class LayerManager extends GfxExtension { private _updateLayer( element: GfxModel | GfxLocalElementModel, props?: Record, - oldValues?: Record + _oldValues?: Record ) { const modelType = this._getModelType(element); const isLocalElem = element instanceof GfxLocalElementModel; @@ -613,16 +613,7 @@ export class LayerManager extends GfxExtension { }; if (shouldUpdateGroupChildren) { - const group = element as GfxModel & GfxGroupCompatibleInterface; - const oldChildIds = childIdsChanged - ? Array.isArray(oldValues?.['childIds']) - ? 
(oldValues['childIds'] as string[]) - : this._groupChildSnapshot.get(group.id) - : undefined; - - const relatedElements = this._getRelatedGroupElements(group, oldChildIds); - this._refreshElementsInLayer(relatedElements); - this._syncGroupChildSnapshot(group); + this._reset(); return true; } diff --git a/blocksuite/framework/std/src/view/element/gfx-block-component.ts b/blocksuite/framework/std/src/view/element/gfx-block-component.ts index 4f9ee7fe0e..9231eb754f 100644 --- a/blocksuite/framework/std/src/view/element/gfx-block-component.ts +++ b/blocksuite/framework/std/src/view/element/gfx-block-component.ts @@ -31,6 +31,13 @@ function updateTransform(element: GfxBlockComponent) { element.style.transform = element.getCSSTransform(); } +function updateZIndex(element: GfxBlockComponent) { + const zIndex = element.toZIndex(); + if (element.style.zIndex !== zIndex) { + element.style.zIndex = zIndex; + } +} + function updateBlockVisibility(view: GfxBlockComponent) { if (view.transformState$.value === 'active') { view.style.visibility = 'visible'; @@ -58,14 +65,22 @@ function handleGfxConnection(instance: GfxBlockComponent) { instance.store.slots.blockUpdated.subscribe(({ type, id }) => { if (id === instance.model.id && type === 'update') { updateTransform(instance); + updateZIndex(instance); } }) ); + instance.disposables.add( + instance.gfx.layer.slots.layerUpdated.subscribe(() => { + updateZIndex(instance); + }) + ); + instance.disposables.add( effect(() => { updateBlockVisibility(instance); updateTransform(instance); + updateZIndex(instance); }) ); } @@ -105,17 +120,23 @@ export abstract class GfxBlockComponent< onBoxSelected(_: BoxSelectionContext) {} + getCSSScaleVal(): number { + const viewport = this.gfx.viewport; + const { zoom, viewScale } = viewport; + return zoom / viewScale; + } + getCSSTransform() { const viewport = this.gfx.viewport; - const { translateX, translateY, zoom } = viewport; + const { translateX, translateY, zoom, viewScale } = viewport; const 
bound = Bound.deserialize(this.model.xywh); - const scaledX = bound.x * zoom; - const scaledY = bound.y * zoom; + const scaledX = (bound.x * zoom) / viewScale; + const scaledY = (bound.y * zoom) / viewScale; const deltaX = scaledX - bound.x; const deltaY = scaledY - bound.y; - return `translate(${translateX + deltaX}px, ${translateY + deltaY}px) scale(${zoom})`; + return `translate(${translateX / viewScale + deltaX}px, ${translateY / viewScale + deltaY}px) scale(${this.getCSSScaleVal()})`; } getRenderingRect() { @@ -219,18 +240,12 @@ export function toGfxBlockComponent< handleGfxConnection(this); } - // eslint-disable-next-line sonarjs/no-identical-functions + getCSSScaleVal(): number { + return GfxBlockComponent.prototype.getCSSScaleVal.call(this); + } + getCSSTransform() { - const viewport = this.gfx.viewport; - const { translateX, translateY, zoom } = viewport; - const bound = Bound.deserialize(this.model.xywh); - - const scaledX = bound.x * zoom; - const scaledY = bound.y * zoom; - const deltaX = scaledX - bound.x; - const deltaY = scaledY - bound.y; - - return `translate(${translateX + deltaX}px, ${translateY + deltaY}px) scale(${zoom})`; + return GfxBlockComponent.prototype.getCSSTransform.call(this); } // eslint-disable-next-line sonarjs/no-identical-functions diff --git a/blocksuite/framework/std/vitest.config.ts b/blocksuite/framework/std/vitest.config.ts index 5bfeaf0148..820078f83a 100644 --- a/blocksuite/framework/std/vitest.config.ts +++ b/blocksuite/framework/std/vitest.config.ts @@ -1,3 +1,4 @@ +import { playwright } from '@vitest/browser-playwright'; import { defineConfig } from 'vitest/config'; export default defineConfig({ @@ -8,15 +9,14 @@ export default defineConfig({ browser: { enabled: true, headless: true, - name: 'chromium', - provider: 'playwright', + instances: [{ browser: 'chromium' }], + provider: playwright(), isolate: false, - providerOptions: {}, }, include: ['src/__tests__/**/*.unit.spec.ts'], testTimeout: 500, coverage: { - 
provider: 'istanbul', // or 'c8' + provider: 'istanbul', reporter: ['lcov'], reportsDirectory: '../../../.coverage/std', }, diff --git a/blocksuite/framework/store/package.json b/blocksuite/framework/store/package.json index eee253ecb1..7bebb8120f 100644 --- a/blocksuite/framework/store/package.json +++ b/blocksuite/framework/store/package.json @@ -29,7 +29,7 @@ "devDependencies": { "@types/lodash.clonedeep": "^4.5.9", "@types/lodash.merge": "^4.6.9", - "vitest": "^3.2.4" + "vitest": "^4.0.18" }, "exports": { ".": "./src/index.ts", diff --git a/blocksuite/framework/store/src/index.ts b/blocksuite/framework/store/src/index.ts index 26bd467296..4ce50eee90 100644 --- a/blocksuite/framework/store/src/index.ts +++ b/blocksuite/framework/store/src/index.ts @@ -7,15 +7,11 @@ export * from './transformer'; export { type IdGenerator, nanoid, uuidv4 } from './utils/id-generator'; export * from './yjs'; -const env = ( - typeof globalThis !== 'undefined' - ? globalThis - : typeof window !== 'undefined' - ? window - : typeof global !== 'undefined' - ? global - : {} -) as Record; +const env = (typeof globalThis !== 'undefined' + ? globalThis + : typeof window !== 'undefined' + ? 
window + : {}) as unknown as Record; const importIdentifier = '__ $BLOCKSUITE_STORE$ __'; if (env[importIdentifier] === true) { diff --git a/blocksuite/framework/store/vitest.config.ts b/blocksuite/framework/store/vitest.config.ts index a5c10c8fc5..07e01b9f46 100644 --- a/blocksuite/framework/store/vitest.config.ts +++ b/blocksuite/framework/store/vitest.config.ts @@ -8,7 +8,7 @@ export default defineConfig({ include: ['src/__tests__/**/*.unit.spec.ts'], testTimeout: 500, coverage: { - provider: 'istanbul', // or 'c8' + provider: 'istanbul', reporter: ['lcov'], reportsDirectory: '../../../.coverage/store', }, diff --git a/blocksuite/framework/sync/package.json b/blocksuite/framework/sync/package.json index e66d9b769f..f48aefc2f1 100644 --- a/blocksuite/framework/sync/package.json +++ b/blocksuite/framework/sync/package.json @@ -19,7 +19,7 @@ "y-protocols": "^1.0.6" }, "devDependencies": { - "vitest": "^3.2.4" + "vitest": "^4.0.18" }, "peerDependencies": { "yjs": "*" diff --git a/blocksuite/framework/sync/vitest.config.ts b/blocksuite/framework/sync/vitest.config.ts index f884bd8ac8..187488790a 100644 --- a/blocksuite/framework/sync/vitest.config.ts +++ b/blocksuite/framework/sync/vitest.config.ts @@ -5,7 +5,7 @@ export default defineConfig({ include: ['src/__tests__/**/*.unit.spec.ts'], testTimeout: 500, coverage: { - provider: 'istanbul', // or 'c8' + provider: 'istanbul', reporter: ['lcov'], reportsDirectory: '../../../.coverage/sync', }, diff --git a/blocksuite/integration-test/package.json b/blocksuite/integration-test/package.json index b010902722..434b68bb6b 100644 --- a/blocksuite/integration-test/package.json +++ b/blocksuite/integration-test/package.json @@ -6,7 +6,7 @@ "dev": "vite", "build": "tsc", "test:unit": "vitest --browser.headless --run", - "test:debug": "PWDEBUG=1 npx vitest" + "test:debug": "PWDEBUG=1 npx vitest --browser.headless=false" }, "sideEffects": false, "keywords": [], @@ -17,7 +17,7 @@ "@blocksuite/icons": "^2.2.17", 
"@floating-ui/dom": "^1.6.13", "@lit/context": "^1.1.3", - "@lottiefiles/dotlottie-wc": "^0.5.0", + "@lottiefiles/dotlottie-wc": "^0.9.4", "@preact/signals-core": "^1.8.0", "@toeverything/theme": "^1.1.23", "@vanilla-extract/css": "^1.17.0", @@ -41,10 +41,11 @@ ], "devDependencies": { "@vanilla-extract/vite-plugin": "^5.0.0", + "@vitest/browser-playwright": "^4.0.18", "vite": "^7.2.7", "vite-plugin-istanbul": "^7.2.1", "vite-plugin-wasm": "^3.5.0", - "vitest": "^3.2.4" + "vitest": "^4.0.18" }, "version": "0.26.3" } diff --git a/blocksuite/integration-test/src/__tests__/edgeless/__screenshots__/connector-dom.spec.ts/Connector-rendering-with-DOM-renderer-should-remove-connector-DOM-node-when-element-is-deleted-1.png b/blocksuite/integration-test/src/__tests__/edgeless/__screenshots__/connector-dom.spec.ts/Connector-rendering-with-DOM-renderer-should-remove-connector-DOM-node-when-element-is-deleted-1.png new file mode 100644 index 0000000000..eac57000ba Binary files /dev/null and b/blocksuite/integration-test/src/__tests__/edgeless/__screenshots__/connector-dom.spec.ts/Connector-rendering-with-DOM-renderer-should-remove-connector-DOM-node-when-element-is-deleted-1.png differ diff --git a/blocksuite/integration-test/src/__tests__/edgeless/__screenshots__/connector-dom.spec.ts/Connector-rendering-with-DOM-renderer-should-render-a-connector-element-as-a-DOM-node-1.png b/blocksuite/integration-test/src/__tests__/edgeless/__screenshots__/connector-dom.spec.ts/Connector-rendering-with-DOM-renderer-should-render-a-connector-element-as-a-DOM-node-1.png new file mode 100644 index 0000000000..eac57000ba Binary files /dev/null and b/blocksuite/integration-test/src/__tests__/edgeless/__screenshots__/connector-dom.spec.ts/Connector-rendering-with-DOM-renderer-should-render-a-connector-element-as-a-DOM-node-1.png differ diff --git 
a/blocksuite/integration-test/src/__tests__/edgeless/__screenshots__/connector-dom.spec.ts/Connector-rendering-with-DOM-renderer-should-render-connector-with-arrow-endpoints-1.png b/blocksuite/integration-test/src/__tests__/edgeless/__screenshots__/connector-dom.spec.ts/Connector-rendering-with-DOM-renderer-should-render-connector-with-arrow-endpoints-1.png new file mode 100644 index 0000000000..eac57000ba Binary files /dev/null and b/blocksuite/integration-test/src/__tests__/edgeless/__screenshots__/connector-dom.spec.ts/Connector-rendering-with-DOM-renderer-should-render-connector-with-arrow-endpoints-1.png differ diff --git a/blocksuite/integration-test/src/__tests__/edgeless/__screenshots__/connector-dom.spec.ts/Connector-rendering-with-DOM-renderer-should-render-connector-with-different-stroke-styles-1.png b/blocksuite/integration-test/src/__tests__/edgeless/__screenshots__/connector-dom.spec.ts/Connector-rendering-with-DOM-renderer-should-render-connector-with-different-stroke-styles-1.png new file mode 100644 index 0000000000..eac57000ba Binary files /dev/null and b/blocksuite/integration-test/src/__tests__/edgeless/__screenshots__/connector-dom.spec.ts/Connector-rendering-with-DOM-renderer-should-render-connector-with-different-stroke-styles-1.png differ diff --git a/blocksuite/integration-test/src/__tests__/edgeless/layer.spec.ts b/blocksuite/integration-test/src/__tests__/edgeless/layer.spec.ts index 1cbb9ec907..cfb1835966 100644 --- a/blocksuite/integration-test/src/__tests__/edgeless/layer.spec.ts +++ b/blocksuite/integration-test/src/__tests__/edgeless/layer.spec.ts @@ -6,6 +6,7 @@ import type { import { ungroupCommand } from '@blocksuite/affine/gfx/group'; import type { GroupElementModel, + MindmapElementModel, NoteBlockModel, } from '@blocksuite/affine/model'; import { generateKeyBetween } from '@blocksuite/affine/std/gfx'; @@ -253,6 +254,40 @@ test('blocks should rerender when their z-index changed', async () => { assertBlocksContent(); }); 
+test('block host z-index should update after reordering', async () => { + const backId = addNote(doc); + const frontId = addNote(doc); + + await wait(); + + const getBlockHost = (id: string) => + document.querySelector( + `affine-edgeless-root gfx-viewport > [data-block-id="${id}"]` + ); + + const backHost = getBlockHost(backId); + const frontHost = getBlockHost(frontId); + + expect(backHost).not.toBeNull(); + expect(frontHost).not.toBeNull(); + expect(Number(backHost!.style.zIndex)).toBeLessThan( + Number(frontHost!.style.zIndex) + ); + + service.crud.updateElement(backId, { + index: service.layer.getReorderedIndex( + service.crud.getElementById(backId)!, + 'front' + ), + }); + + await wait(); + + expect(Number(backHost!.style.zIndex)).toBeGreaterThan( + Number(frontHost!.style.zIndex) + ); +}); + describe('layer reorder functionality', () => { let ids: string[] = []; @@ -428,14 +463,17 @@ describe('group related functionality', () => { const elements = [ service.crud.addElement('shape', { shapeType: 'rect', + xywh: '[0,0,100,100]', })!, addNote(doc), service.crud.addElement('shape', { shapeType: 'rect', + xywh: '[120,0,100,100]', })!, addNote(doc), service.crud.addElement('shape', { shapeType: 'rect', + xywh: '[240,0,100,100]', })!, ]; @@ -528,6 +566,35 @@ describe('group related functionality', () => { expect(service.layer.layers[1].elements[0]).toBe(group); }); + test("change mindmap index should update its nodes' layer", async () => { + const noteId = addNote(doc); + const mindmapId = service.crud.addElement('mindmap', { + children: { + text: 'root', + children: [{ text: 'child' }], + }, + })!; + + await wait(); + + const note = service.crud.getElementById(noteId)!; + const mindmap = service.crud.getElementById( + mindmapId + )! 
as MindmapElementModel; + const root = mindmap.tree.element; + + expect(service.layer.getZIndex(root)).toBeGreaterThan( + service.layer.getZIndex(note) + ); + + mindmap.index = service.layer.getReorderedIndex(mindmap, 'back'); + await wait(); + + expect(service.layer.getZIndex(root)).toBeLessThan( + service.layer.getZIndex(note) + ); + }); + test('should keep relative index order of elements after group, ungroup, undo, redo', () => { const edgeless = getDocRootBlock(doc, editor, 'edgeless'); const elementIds = [ @@ -769,6 +836,7 @@ test('indexed canvas should be inserted into edgeless portal when switch to edge service.crud.addElement('shape', { shapeType: 'rect', + xywh: '[0,0,100,100]', })!; addNote(doc); @@ -777,6 +845,7 @@ test('indexed canvas should be inserted into edgeless portal when switch to edge service.crud.addElement('shape', { shapeType: 'rect', + xywh: '[120,0,100,100]', })!; editor.mode = 'page'; @@ -792,10 +861,10 @@ test('indexed canvas should be inserted into edgeless portal when switch to edge '.indexable-canvas' )[0] as HTMLCanvasElement; - expect(indexedCanvas.width).toBe( + expect(indexedCanvas.width).toBeLessThanOrEqual( (surface.renderer as CanvasRenderer).canvas.width ); - expect(indexedCanvas.height).toBe( + expect(indexedCanvas.height).toBeLessThanOrEqual( (surface.renderer as CanvasRenderer).canvas.height ); expect(indexedCanvas.width).not.toBe(0); diff --git a/blocksuite/integration-test/vitest.config.ts b/blocksuite/integration-test/vitest.config.ts index f4ad799ce6..fea1d8aec6 100644 --- a/blocksuite/integration-test/vitest.config.ts +++ b/blocksuite/integration-test/vitest.config.ts @@ -1,4 +1,5 @@ import { vanillaExtractPlugin } from '@vanilla-extract/vite-plugin'; +import { playwright } from '@vitest/browser-playwright'; import { defineConfig } from 'vitest/config'; export default defineConfig(_configEnv => @@ -18,13 +19,13 @@ export default defineConfig(_configEnv => retry: process.env.CI === 'true' ? 
3 : 0, browser: { enabled: true, - headless: process.env.CI === 'true', + headless: true, instances: [ { browser: 'chromium' }, { browser: 'firefox' }, { browser: 'webkit' }, ], - provider: 'playwright', + provider: playwright(), isolate: false, viewport: { width: 1024, @@ -32,16 +33,13 @@ export default defineConfig(_configEnv => }, }, coverage: { - provider: 'istanbul', // or 'c8' + provider: 'istanbul', reporter: ['lcov'], reportsDirectory: '../../.coverage/integration-test', }, deps: { interopDefault: true, }, - testTransformMode: { - web: ['src/__tests__/**/*.spec.ts'], - }, }, }) ); diff --git a/package.json b/package.json index 71ffb21f79..5a5b8372a4 100644 --- a/package.json +++ b/package.json @@ -22,7 +22,7 @@ "af": "r affine.ts", "dev": "yarn affine dev", "build": "yarn affine build", - "lint:eslint": "cross-env NODE_OPTIONS=\"--max-old-space-size=8192\" eslint --report-unused-disable-directives-severity=off . --cache", + "lint:eslint": "cross-env NODE_OPTIONS=\"--max-old-space-size=16384\" eslint --report-unused-disable-directives-severity=off . 
--cache", "lint:eslint:fix": "yarn lint:eslint --fix --fix-type problem,suggestion,layout", "lint:prettier": "prettier --ignore-unknown --cache --check .", "lint:prettier:fix": "prettier --ignore-unknown --cache --write .", @@ -56,7 +56,7 @@ "@faker-js/faker": "^10.1.0", "@istanbuljs/schema": "^0.1.3", "@magic-works/i18n-codegen": "^0.6.1", - "@playwright/test": "=1.52.0", + "@playwright/test": "=1.58.2", "@smarttools/eslint-plugin-rxjs": "^1.0.8", "@taplo/cli": "^0.7.0", "@toeverything/infra": "workspace:*", @@ -64,9 +64,9 @@ "@types/node": "^22.0.0", "@typescript-eslint/parser": "^8.55.0", "@vanilla-extract/vite-plugin": "^5.0.0", - "@vitest/browser": "^3.2.4", - "@vitest/coverage-istanbul": "^3.2.4", - "@vitest/ui": "^3.2.4", + "@vitest/browser": "^4.0.18", + "@vitest/coverage-istanbul": "^4.0.18", + "@vitest/ui": "^4.0.18", "cross-env": "^10.1.0", "electron": "^39.0.0", "eslint": "^9.39.2", @@ -90,7 +90,7 @@ "typescript-eslint": "^8.55.0", "unplugin-swc": "^1.5.9", "vite": "^7.2.7", - "vitest": "^3.2.4" + "vitest": "^4.0.18" }, "packageManager": "yarn@4.12.0", "resolutions": { diff --git a/packages/backend/native/Cargo.toml b/packages/backend/native/Cargo.toml index faecf5eebe..8309dc13a9 100644 --- a/packages/backend/native/Cargo.toml +++ b/packages/backend/native/Cargo.toml @@ -14,13 +14,23 @@ affine_common = { workspace = true, features = [ "napi", "ydoc-loader", ] } +anyhow = { workspace = true } chrono = { workspace = true } file-format = { workspace = true } +image = { workspace = true } infer = { workspace = true } +libwebp-sys = { workspace = true } +little_exif = { workspace = true } +llm_adapter = { workspace = true, default-features = false, features = [ + "ureq-client", +] } +matroska = { workspace = true } mp4parse = { workspace = true } napi = { workspace = true, features = ["async"] } napi-derive = { workspace = true } rand = { workspace = true } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } sha3 = { 
workspace = true } tiktoken-rs = { workspace = true } v_htmlescape = { workspace = true } diff --git a/packages/backend/native/fixtures/audio-only.mka b/packages/backend/native/fixtures/audio-only.mka new file mode 100644 index 0000000000..0d7e50b25b Binary files /dev/null and b/packages/backend/native/fixtures/audio-only.mka differ diff --git a/packages/backend/native/fixtures/audio-only.webm b/packages/backend/native/fixtures/audio-only.webm new file mode 100644 index 0000000000..92cfe023b1 Binary files /dev/null and b/packages/backend/native/fixtures/audio-only.webm differ diff --git a/packages/backend/native/fixtures/audio-video.webm b/packages/backend/native/fixtures/audio-video.webm new file mode 100644 index 0000000000..b17c5720fa Binary files /dev/null and b/packages/backend/native/fixtures/audio-video.webm differ diff --git a/packages/backend/native/index.d.ts b/packages/backend/native/index.d.ts index 825cf5dbd0..7fe7ecf820 100644 --- a/packages/backend/native/index.d.ts +++ b/packages/backend/native/index.d.ts @@ -1,5 +1,9 @@ /* auto-generated by NAPI-RS */ /* eslint-disable */ +export declare class LlmStreamHandle { + abort(): void +} + export declare class Tokenizer { count(content: string, allowedSpecial?: Array | undefined | null): number } @@ -46,6 +50,16 @@ export declare function getMime(input: Uint8Array): string export declare function htmlSanitize(input: string): string +export declare function llmDispatch(protocol: string, backendConfigJson: string, requestJson: string): string + +export declare function llmDispatchStream(protocol: string, backendConfigJson: string, requestJson: string, callback: ((err: Error | null, arg: string) => void)): LlmStreamHandle + +export declare function llmEmbeddingDispatch(protocol: string, backendConfigJson: string, requestJson: string): string + +export declare function llmRerankDispatch(protocol: string, backendConfigJson: string, requestJson: string): string + +export declare function 
llmStructuredDispatch(protocol: string, backendConfigJson: string, requestJson: string): string + /** * Merge updates in form like `Y.applyUpdate(doc, update)` way and return the * result binary. @@ -75,6 +89,8 @@ export interface NativeCrawlResult { export interface NativeMarkdownResult { title: string markdown: string + knownUnsupportedBlocks: Array + unknownBlocks: Array } export interface NativePageDocContent { @@ -102,6 +118,8 @@ export declare function parsePageDoc(docBin: Buffer, maxSummaryLength?: number | export declare function parseWorkspaceDoc(docBin: Buffer): NativeWorkspaceDocContent | null +export declare function processImage(input: Buffer, maxEdge: number, keepExif: boolean): Promise + export declare function readAllDocIdsFromRootDoc(docBin: Buffer, includeTrash?: boolean | undefined | null): Array /** diff --git a/packages/backend/native/src/doc.rs b/packages/backend/native/src/doc.rs index d11ba85d3a..743f25ae60 100644 --- a/packages/backend/native/src/doc.rs +++ b/packages/backend/native/src/doc.rs @@ -9,6 +9,8 @@ use napi_derive::napi; pub struct NativeMarkdownResult { pub title: String, pub markdown: String, + pub known_unsupported_blocks: Vec, + pub unknown_blocks: Vec, } impl From for NativeMarkdownResult { @@ -16,6 +18,8 @@ impl From for NativeMarkdownResult { Self { title: result.title, markdown: result.markdown, + known_unsupported_blocks: result.known_unsupported_blocks, + unknown_blocks: result.unknown_blocks, } } } diff --git a/packages/backend/native/src/file_type.rs b/packages/backend/native/src/file_type.rs index 4f87b2412c..dc43ce1ed9 100644 --- a/packages/backend/native/src/file_type.rs +++ b/packages/backend/native/src/file_type.rs @@ -1,3 +1,4 @@ +use matroska::Matroska; use mp4parse::{TrackType, read_mp4}; use napi_derive::napi; @@ -8,7 +9,13 @@ pub fn get_mime(input: &[u8]) -> String { } else { file_format::FileFormat::from_bytes(input).media_type().to_string() }; - if mimetype == "video/mp4" { + if let Some(container) = 
matroska_container_kind(input).or(match mimetype.as_str() { + "video/webm" | "application/webm" => Some(ContainerKind::WebM), + "video/x-matroska" | "application/x-matroska" => Some(ContainerKind::Matroska), + _ => None, + }) { + detect_matroska_flavor(input, container, &mimetype) + } else if mimetype == "video/mp4" { detect_mp4_flavor(input) } else { mimetype @@ -37,3 +44,68 @@ fn detect_mp4_flavor(input: &[u8]) -> String { Err(_) => "video/mp4".to_string(), } } + +#[derive(Clone, Copy)] +enum ContainerKind { + WebM, + Matroska, +} + +impl ContainerKind { + fn audio_mime(&self) -> &'static str { + match self { + ContainerKind::WebM => "audio/webm", + ContainerKind::Matroska => "audio/x-matroska", + } + } +} + +fn detect_matroska_flavor(input: &[u8], container: ContainerKind, fallback: &str) -> String { + match Matroska::open(std::io::Cursor::new(input)) { + Ok(file) => { + let has_video = file.video_tracks().next().is_some(); + let has_audio = file.audio_tracks().next().is_some(); + if !has_video && has_audio { + container.audio_mime().to_string() + } else { + fallback.to_string() + } + } + Err(_) => fallback.to_string(), + } +} + +fn matroska_container_kind(input: &[u8]) -> Option { + let header = &input[..1024.min(input.len())]; + if header.windows(4).any(|window| window.eq_ignore_ascii_case(b"webm")) { + Some(ContainerKind::WebM) + } else if header.windows(8).any(|window| window.eq_ignore_ascii_case(b"matroska")) { + Some(ContainerKind::Matroska) + } else { + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + + const AUDIO_ONLY_WEBM: &[u8] = include_bytes!("../fixtures/audio-only.webm"); + const AUDIO_VIDEO_WEBM: &[u8] = include_bytes!("../fixtures/audio-video.webm"); + const AUDIO_ONLY_MATROSKA: &[u8] = include_bytes!("../fixtures/audio-only.mka"); + + #[test] + fn detects_audio_only_webm_as_audio() { + assert_eq!(get_mime(AUDIO_ONLY_WEBM), "audio/webm"); + } + + #[test] + fn preserves_video_webm() { + assert_eq!(get_mime(AUDIO_VIDEO_WEBM), 
"video/webm"); + } + + #[test] + fn detects_audio_only_matroska_as_audio() { + assert_eq!(get_mime(AUDIO_ONLY_MATROSKA), "audio/x-matroska"); + } +} diff --git a/packages/backend/native/src/image.rs b/packages/backend/native/src/image.rs new file mode 100644 index 0000000000..5909a62f42 --- /dev/null +++ b/packages/backend/native/src/image.rs @@ -0,0 +1,353 @@ +use std::io::Cursor; + +use anyhow::{Context, Result as AnyResult, bail}; +use image::{ + AnimationDecoder, DynamicImage, ImageDecoder, ImageFormat, ImageReader, + codecs::{gif::GifDecoder, png::PngDecoder, webp::WebPDecoder}, + imageops::FilterType, + metadata::Orientation, +}; +use libwebp_sys::{ + WEBP_MUX_ABI_VERSION, WebPData, WebPDataClear, WebPDataInit, WebPEncodeRGBA, WebPFree, WebPMuxAssemble, + WebPMuxCreateInternal, WebPMuxDelete, WebPMuxError, WebPMuxSetChunk, +}; +use little_exif::{exif_tag::ExifTag, filetype::FileExtension, metadata::Metadata}; +use napi::{ + Env, Error, Result, Status, Task, + bindgen_prelude::{AsyncTask, Buffer}, +}; +use napi_derive::napi; + +const WEBP_QUALITY: f32 = 80.0; +const MAX_IMAGE_DIMENSION: u32 = 16_384; +const MAX_IMAGE_PIXELS: u64 = 40_000_000; + +pub struct AsyncProcessImageTask { + input: Vec, + max_edge: u32, + keep_exif: bool, +} + +#[napi] +impl Task for AsyncProcessImageTask { + type Output = Vec; + type JsValue = Buffer; + + fn compute(&mut self) -> Result { + process_image_inner(&self.input, self.max_edge, self.keep_exif) + .map_err(|error| Error::new(Status::InvalidArg, error.to_string())) + } + + fn resolve(&mut self, _: Env, output: Self::Output) -> Result { + Ok(output.into()) + } +} + +#[napi] +pub fn process_image(input: Buffer, max_edge: u32, keep_exif: bool) -> AsyncTask { + AsyncTask::new(AsyncProcessImageTask { + input: input.to_vec(), + max_edge, + keep_exif, + }) +} + +fn process_image_inner(input: &[u8], max_edge: u32, keep_exif: bool) -> AnyResult> { + if max_edge == 0 { + bail!("max_edge must be greater than 0"); + } + + let format = 
image::guess_format(input).context("unsupported image format")?; + let (width, height) = read_dimensions(input, format)?; + validate_dimensions(width, height)?; + let mut image = decode_image(input, format)?; + let orientation = read_orientation(input, format)?; + image.apply_orientation(orientation); + + if image.width().max(image.height()) > max_edge { + image = image.resize(max_edge, max_edge, FilterType::Lanczos3); + } + + let mut output = encode_webp_lossy(&image.into_rgba8())?; + + if keep_exif { + preserve_exif(input, format, &mut output)?; + } + + Ok(output) +} + +fn read_dimensions(input: &[u8], format: ImageFormat) -> AnyResult<(u32, u32)> { + ImageReader::with_format(Cursor::new(input), format) + .into_dimensions() + .context("failed to decode image") +} + +fn validate_dimensions(width: u32, height: u32) -> AnyResult<()> { + if width == 0 || height == 0 { + bail!("failed to decode image"); + } + + if width > MAX_IMAGE_DIMENSION || height > MAX_IMAGE_DIMENSION { + bail!("image dimensions exceed limit"); + } + + if u64::from(width) * u64::from(height) > MAX_IMAGE_PIXELS { + bail!("image pixel count exceeds limit"); + } + + Ok(()) +} + +fn decode_image(input: &[u8], format: ImageFormat) -> AnyResult { + Ok(match format { + ImageFormat::Gif => { + let decoder = GifDecoder::new(Cursor::new(input)).context("failed to decode image")?; + let frame = decoder + .into_frames() + .next() + .transpose() + .context("failed to decode image")? + .context("image does not contain any frames")?; + DynamicImage::ImageRgba8(frame.into_buffer()) + } + ImageFormat::Png => { + let decoder = PngDecoder::new(Cursor::new(input)).context("failed to decode image")?; + if decoder.is_apng().context("failed to decode image")? { + let frame = decoder + .apng() + .context("failed to decode image")? + .into_frames() + .next() + .transpose() + .context("failed to decode image")? 
+ .context("image does not contain any frames")?; + DynamicImage::ImageRgba8(frame.into_buffer()) + } else { + DynamicImage::from_decoder(decoder).context("failed to decode image")? + } + } + ImageFormat::WebP => { + let decoder = WebPDecoder::new(Cursor::new(input)).context("failed to decode image")?; + let frame = decoder + .into_frames() + .next() + .transpose() + .context("failed to decode image")? + .context("image does not contain any frames")?; + DynamicImage::ImageRgba8(frame.into_buffer()) + } + _ => { + let reader = ImageReader::with_format(Cursor::new(input), format); + let decoder = reader.into_decoder().context("failed to decode image")?; + DynamicImage::from_decoder(decoder).context("failed to decode image")? + } + }) +} + +fn read_orientation(input: &[u8], format: ImageFormat) -> AnyResult { + Ok(match format { + ImageFormat::Gif => GifDecoder::new(Cursor::new(input)) + .context("failed to decode image")? + .orientation() + .context("failed to decode image")?, + ImageFormat::Png => PngDecoder::new(Cursor::new(input)) + .context("failed to decode image")? + .orientation() + .context("failed to decode image")?, + ImageFormat::WebP => WebPDecoder::new(Cursor::new(input)) + .context("failed to decode image")? + .orientation() + .context("failed to decode image")?, + _ => ImageReader::with_format(Cursor::new(input), format) + .into_decoder() + .context("failed to decode image")? 
+ .orientation() + .context("failed to decode image")?, + }) +} + +fn encode_webp_lossy(image: &image::RgbaImage) -> AnyResult> { + let width = i32::try_from(image.width()).context("image width is too large")?; + let height = i32::try_from(image.height()).context("image height is too large")?; + let stride = width.checked_mul(4).context("image width is too large")?; + + let mut output = std::ptr::null_mut(); + let encoded_len = unsafe { WebPEncodeRGBA(image.as_ptr(), width, height, stride, WEBP_QUALITY, &mut output) }; + + if output.is_null() || encoded_len == 0 { + bail!("failed to encode webp"); + } + + let encoded = unsafe { std::slice::from_raw_parts(output, encoded_len) }.to_vec(); + unsafe { + WebPFree(output.cast()); + } + + Ok(encoded) +} + +fn preserve_exif(input: &[u8], format: ImageFormat, output: &mut Vec) -> AnyResult<()> { + let Some(file_type) = map_exif_file_type(format) else { + return Ok(()); + }; + + let input = input.to_vec(); + let Ok(mut metadata) = Metadata::new_from_vec(&input, file_type) else { + return Ok(()); + }; + + metadata.remove_tag(ExifTag::Orientation(vec![1])); + + if !metadata.get_ifds().iter().any(|ifd| !ifd.get_tags().is_empty()) { + return Ok(()); + } + + let encoded_metadata = metadata.encode().context("failed to preserve exif metadata")?; + let source = WebPData { + bytes: output.as_ptr(), + size: output.len(), + }; + let exif = WebPData { + bytes: encoded_metadata.as_ptr(), + size: encoded_metadata.len(), + }; + let mut assembled = WebPData::default(); + let mux = unsafe { WebPMuxCreateInternal(&source, 1, WEBP_MUX_ABI_VERSION as _) }; + if mux.is_null() { + bail!("failed to preserve exif metadata"); + } + + let encoded = (|| -> AnyResult> { + if unsafe { WebPMuxSetChunk(mux, c"EXIF".as_ptr(), &exif, 1) } != WebPMuxError::WEBP_MUX_OK { + bail!("failed to preserve exif metadata"); + } + + WebPDataInit(&mut assembled); + + if unsafe { WebPMuxAssemble(mux, &mut assembled) } != WebPMuxError::WEBP_MUX_OK { + bail!("failed to 
preserve exif metadata"); + } + + Ok(unsafe { std::slice::from_raw_parts(assembled.bytes, assembled.size) }.to_vec()) + })(); + + unsafe { + WebPDataClear(&mut assembled); + WebPMuxDelete(mux); + } + + *output = encoded?; + + Ok(()) +} + +fn map_exif_file_type(format: ImageFormat) -> Option { + match format { + ImageFormat::Jpeg => Some(FileExtension::JPEG), + ImageFormat::Png => Some(FileExtension::PNG { as_zTXt_chunk: true }), + ImageFormat::Tiff => Some(FileExtension::TIFF), + ImageFormat::WebP => Some(FileExtension::WEBP), + _ => None, + } +} + +#[cfg(test)] +mod tests { + use image::{ExtendedColorType, GenericImageView, ImageEncoder, codecs::png::PngEncoder}; + + use super::*; + + fn encode_png(width: u32, height: u32) -> Vec { + let image = image::RgbaImage::from_pixel(width, height, image::Rgba([255, 0, 0, 255])); + let mut encoded = Vec::new(); + PngEncoder::new(&mut encoded) + .write_image(image.as_raw(), width, height, ExtendedColorType::Rgba8) + .unwrap(); + encoded + } + + fn encode_bmp_header(width: u32, height: u32) -> Vec { + let mut encoded = Vec::with_capacity(54); + encoded.extend_from_slice(b"BM"); + encoded.extend_from_slice(&(54u32).to_le_bytes()); + encoded.extend_from_slice(&0u16.to_le_bytes()); + encoded.extend_from_slice(&0u16.to_le_bytes()); + encoded.extend_from_slice(&(54u32).to_le_bytes()); + encoded.extend_from_slice(&(40u32).to_le_bytes()); + encoded.extend_from_slice(&(width as i32).to_le_bytes()); + encoded.extend_from_slice(&(height as i32).to_le_bytes()); + encoded.extend_from_slice(&1u16.to_le_bytes()); + encoded.extend_from_slice(&24u16.to_le_bytes()); + encoded.extend_from_slice(&0u32.to_le_bytes()); + encoded.extend_from_slice(&0u32.to_le_bytes()); + encoded.extend_from_slice(&0u32.to_le_bytes()); + encoded.extend_from_slice(&0u32.to_le_bytes()); + encoded.extend_from_slice(&0u32.to_le_bytes()); + encoded.extend_from_slice(&0u32.to_le_bytes()); + encoded + } + + #[test] + fn process_image_keeps_small_dimensions() { + let png = 
encode_png(8, 6); + let output = process_image_inner(&png, 512, false).unwrap(); + + let format = image::guess_format(&output).unwrap(); + assert_eq!(format, ImageFormat::WebP); + + let decoded = image::load_from_memory(&output).unwrap(); + assert_eq!(decoded.dimensions(), (8, 6)); + } + + #[test] + fn process_image_scales_down_large_dimensions() { + let png = encode_png(1024, 256); + let output = process_image_inner(&png, 512, false).unwrap(); + let decoded = image::load_from_memory(&output).unwrap(); + + assert_eq!(decoded.dimensions(), (512, 128)); + } + + #[test] + fn process_image_preserves_exif_without_orientation() { + let png = encode_png(8, 8); + let mut png_with_exif = png.clone(); + let mut metadata = Metadata::new(); + metadata.set_tag(ExifTag::ImageDescription("copilot".to_string())); + metadata.set_tag(ExifTag::Orientation(vec![6])); + metadata + .write_to_vec(&mut png_with_exif, FileExtension::PNG { as_zTXt_chunk: true }) + .unwrap(); + + let output = process_image_inner(&png_with_exif, 512, true).unwrap(); + let decoded_metadata = Metadata::new_from_vec(&output, FileExtension::WEBP).unwrap(); + + assert!( + decoded_metadata + .get_tag(&ExifTag::ImageDescription(String::new())) + .next() + .is_some() + ); + assert!( + decoded_metadata + .get_tag(&ExifTag::Orientation(vec![1])) + .next() + .is_none() + ); + } + + #[test] + fn process_image_rejects_invalid_input() { + let error = process_image_inner(b"not-an-image", 512, false).unwrap_err(); + assert_eq!(error.to_string(), "unsupported image format"); + } + + #[test] + fn process_image_rejects_images_over_dimension_limit_before_decode() { + let bmp = encode_bmp_header(MAX_IMAGE_DIMENSION + 1, 1); + let error = process_image_inner(&bmp, 512, false).unwrap_err(); + + assert_eq!(error.to_string(), "image dimensions exceed limit"); + } +} diff --git a/packages/backend/native/src/lib.rs b/packages/backend/native/src/lib.rs index 253c492858..988ef8fdb8 100644 --- a/packages/backend/native/src/lib.rs +++ 
b/packages/backend/native/src/lib.rs @@ -7,6 +7,8 @@ pub mod doc_loader; pub mod file_type; pub mod hashcash; pub mod html_sanitize; +pub mod image; +pub mod llm; pub mod tiktoken; use affine_common::napi_utils::map_napi_err; diff --git a/packages/backend/native/src/llm.rs b/packages/backend/native/src/llm.rs new file mode 100644 index 0000000000..26ef80aed2 --- /dev/null +++ b/packages/backend/native/src/llm.rs @@ -0,0 +1,414 @@ +use std::sync::{ + Arc, + atomic::{AtomicBool, Ordering}, +}; + +use llm_adapter::{ + backend::{ + BackendConfig, BackendError, BackendProtocol, DefaultHttpClient, dispatch_embedding_request, dispatch_request, + dispatch_rerank_request, dispatch_stream_events_with, dispatch_structured_request, + }, + core::{CoreRequest, EmbeddingRequest, RerankRequest, StreamEvent, StructuredRequest}, + middleware::{ + MiddlewareConfig, PipelineContext, RequestMiddleware, StreamMiddleware, citation_indexing, clamp_max_tokens, + normalize_messages, run_request_middleware_chain, run_stream_middleware_chain, stream_event_normalize, + tool_schema_rewrite, + }, +}; +use napi::{ + Error, Result, Status, + threadsafe_function::{ThreadsafeFunction, ThreadsafeFunctionCallMode}, +}; +use serde::Deserialize; + +pub const STREAM_END_MARKER: &str = "__AFFINE_LLM_STREAM_END__"; +const STREAM_ABORTED_REASON: &str = "__AFFINE_LLM_STREAM_ABORTED__"; +const STREAM_CALLBACK_DISPATCH_FAILED_REASON: &str = "__AFFINE_LLM_STREAM_CALLBACK_DISPATCH_FAILED__"; + +#[derive(Debug, Clone, Default, Deserialize)] +#[serde(default)] +struct LlmMiddlewarePayload { + request: Vec, + stream: Vec, + config: MiddlewareConfig, +} + +#[derive(Debug, Clone, Deserialize)] +struct LlmDispatchPayload { + #[serde(flatten)] + request: CoreRequest, + #[serde(default)] + middleware: LlmMiddlewarePayload, +} + +#[derive(Debug, Clone, Deserialize)] +struct LlmStructuredDispatchPayload { + #[serde(flatten)] + request: StructuredRequest, + #[serde(default)] + middleware: LlmMiddlewarePayload, +} + 
+#[derive(Debug, Clone, Deserialize)] +struct LlmRerankDispatchPayload { + #[serde(flatten)] + request: RerankRequest, +} + +#[napi] +pub struct LlmStreamHandle { + aborted: Arc, +} + +#[napi] +impl LlmStreamHandle { + #[napi] + pub fn abort(&self) { + self.aborted.store(true, Ordering::SeqCst); + } +} + +#[napi(catch_unwind)] +pub fn llm_dispatch(protocol: String, backend_config_json: String, request_json: String) -> Result { + let protocol = parse_protocol(&protocol)?; + let config: BackendConfig = serde_json::from_str(&backend_config_json).map_err(map_json_error)?; + let payload: LlmDispatchPayload = serde_json::from_str(&request_json).map_err(map_json_error)?; + let request = apply_request_middlewares(payload.request, &payload.middleware)?; + + let response = + dispatch_request(&DefaultHttpClient::default(), &config, protocol, &request).map_err(map_backend_error)?; + + serde_json::to_string(&response).map_err(map_json_error) +} + +#[napi(catch_unwind)] +pub fn llm_structured_dispatch(protocol: String, backend_config_json: String, request_json: String) -> Result { + let protocol = parse_protocol(&protocol)?; + let config: BackendConfig = serde_json::from_str(&backend_config_json).map_err(map_json_error)?; + let payload: LlmStructuredDispatchPayload = serde_json::from_str(&request_json).map_err(map_json_error)?; + let request = apply_structured_request_middlewares(payload.request, &payload.middleware)?; + + let response = dispatch_structured_request(&DefaultHttpClient::default(), &config, protocol, &request) + .map_err(map_backend_error)?; + + serde_json::to_string(&response).map_err(map_json_error) +} + +#[napi(catch_unwind)] +pub fn llm_embedding_dispatch(protocol: String, backend_config_json: String, request_json: String) -> Result { + let protocol = parse_protocol(&protocol)?; + let config: BackendConfig = serde_json::from_str(&backend_config_json).map_err(map_json_error)?; + let request: EmbeddingRequest = 
serde_json::from_str(&request_json).map_err(map_json_error)?; + + let response = dispatch_embedding_request(&DefaultHttpClient::default(), &config, protocol, &request) + .map_err(map_backend_error)?; + + serde_json::to_string(&response).map_err(map_json_error) +} + +#[napi(catch_unwind)] +pub fn llm_rerank_dispatch(protocol: String, backend_config_json: String, request_json: String) -> Result { + let protocol = parse_protocol(&protocol)?; + let config: BackendConfig = serde_json::from_str(&backend_config_json).map_err(map_json_error)?; + let payload: LlmRerankDispatchPayload = serde_json::from_str(&request_json).map_err(map_json_error)?; + + let response = dispatch_rerank_request(&DefaultHttpClient::default(), &config, protocol, &payload.request) + .map_err(map_backend_error)?; + + serde_json::to_string(&response).map_err(map_json_error) +} + +#[napi(catch_unwind)] +pub fn llm_dispatch_stream( + protocol: String, + backend_config_json: String, + request_json: String, + callback: ThreadsafeFunction, +) -> Result { + let protocol = parse_protocol(&protocol)?; + let config: BackendConfig = serde_json::from_str(&backend_config_json).map_err(map_json_error)?; + let payload: LlmDispatchPayload = serde_json::from_str(&request_json).map_err(map_json_error)?; + let request = apply_request_middlewares(payload.request, &payload.middleware)?; + let middleware = payload.middleware.clone(); + + let aborted = Arc::new(AtomicBool::new(false)); + let aborted_in_worker = aborted.clone(); + + std::thread::spawn(move || { + let chain = match resolve_stream_chain(&middleware.stream) { + Ok(chain) => chain, + Err(error) => { + emit_error_event(&callback, error.reason.clone(), "middleware_error"); + let _ = callback.call( + Ok(STREAM_END_MARKER.to_string()), + ThreadsafeFunctionCallMode::NonBlocking, + ); + return; + } + }; + let mut pipeline = StreamPipeline::new(chain, middleware.config.clone()); + let mut aborted_by_user = false; + let mut callback_dispatch_failed = false; + + let 
result = dispatch_stream_events_with(&DefaultHttpClient::default(), &config, protocol, &request, |event| { + if aborted_in_worker.load(Ordering::Relaxed) { + aborted_by_user = true; + return Err(BackendError::Http(STREAM_ABORTED_REASON.to_string())); + } + + for event in pipeline.process(event) { + let status = emit_stream_event(&callback, &event); + if status != Status::Ok { + callback_dispatch_failed = true; + return Err(BackendError::Http(format!( + "{STREAM_CALLBACK_DISPATCH_FAILED_REASON}:{status}" + ))); + } + } + + Ok(()) + }); + + if !aborted_by_user { + for event in pipeline.finish() { + if aborted_in_worker.load(Ordering::Relaxed) { + aborted_by_user = true; + break; + } + if emit_stream_event(&callback, &event) != Status::Ok { + callback_dispatch_failed = true; + break; + } + } + } + + if let Err(error) = result + && !aborted_by_user + && !callback_dispatch_failed + && !is_abort_error(&error) + && !is_callback_dispatch_failed_error(&error) + { + emit_error_event(&callback, error.to_string(), "dispatch_error"); + } + + if !callback_dispatch_failed { + let _ = callback.call( + Ok(STREAM_END_MARKER.to_string()), + ThreadsafeFunctionCallMode::NonBlocking, + ); + } + }); + + Ok(LlmStreamHandle { aborted }) +} + +fn apply_request_middlewares(request: CoreRequest, middleware: &LlmMiddlewarePayload) -> Result { + let chain = resolve_request_chain(&middleware.request)?; + Ok(run_request_middleware_chain(request, &middleware.config, &chain)) +} + +fn apply_structured_request_middlewares( + request: StructuredRequest, + middleware: &LlmMiddlewarePayload, +) -> Result { + let mut core = request.as_core_request(); + core = apply_request_middlewares(core, middleware)?; + + Ok(StructuredRequest { + model: core.model, + messages: core.messages, + schema: core + .response_schema + .ok_or_else(|| Error::new(Status::InvalidArg, "Structured request schema is required"))?, + max_tokens: core.max_tokens, + temperature: core.temperature, + reasoning: core.reasoning, + strict: 
request.strict, + response_mime_type: request.response_mime_type, + }) +} + +#[derive(Clone)] +struct StreamPipeline { + chain: Vec, + config: MiddlewareConfig, + context: PipelineContext, +} + +impl StreamPipeline { + fn new(chain: Vec, config: MiddlewareConfig) -> Self { + Self { + chain, + config, + context: PipelineContext::default(), + } + } + + fn process(&mut self, event: StreamEvent) -> Vec { + run_stream_middleware_chain(event, &mut self.context, &self.config, &self.chain) + } + + fn finish(&mut self) -> Vec { + self.context.flush_pending_deltas(); + self.context.drain_queued_events() + } +} + +fn emit_stream_event(callback: &ThreadsafeFunction, event: &StreamEvent) -> Status { + let value = serde_json::to_string(event).unwrap_or_else(|error| { + serde_json::json!({ + "type": "error", + "message": format!("failed to serialize stream event: {error}"), + }) + .to_string() + }); + + callback.call(Ok(value), ThreadsafeFunctionCallMode::NonBlocking) +} + +fn emit_error_event(callback: &ThreadsafeFunction, message: String, code: &str) { + let error_event = serde_json::to_string(&StreamEvent::Error { + message: message.clone(), + code: Some(code.to_string()), + }) + .unwrap_or_else(|_| { + serde_json::json!({ + "type": "error", + "message": message, + "code": code, + }) + .to_string() + }); + + let _ = callback.call(Ok(error_event), ThreadsafeFunctionCallMode::NonBlocking); +} + +fn is_abort_error(error: &BackendError) -> bool { + matches!( + error, + BackendError::Http(reason) if reason == STREAM_ABORTED_REASON + ) +} + +fn is_callback_dispatch_failed_error(error: &BackendError) -> bool { + matches!( + error, + BackendError::Http(reason) if reason.starts_with(STREAM_CALLBACK_DISPATCH_FAILED_REASON) + ) +} + +fn resolve_request_chain(request: &[String]) -> Result> { + if request.is_empty() { + return Ok(vec![normalize_messages, tool_schema_rewrite]); + } + + request + .iter() + .map(|name| match name.as_str() { + "normalize_messages" => Ok(normalize_messages as 
RequestMiddleware), + "clamp_max_tokens" => Ok(clamp_max_tokens as RequestMiddleware), + "tool_schema_rewrite" => Ok(tool_schema_rewrite as RequestMiddleware), + _ => Err(Error::new( + Status::InvalidArg, + format!("Unsupported request middleware: {name}"), + )), + }) + .collect() +} + +fn resolve_stream_chain(stream: &[String]) -> Result> { + if stream.is_empty() { + return Ok(vec![stream_event_normalize, citation_indexing]); + } + + stream + .iter() + .map(|name| match name.as_str() { + "stream_event_normalize" => Ok(stream_event_normalize as StreamMiddleware), + "citation_indexing" => Ok(citation_indexing as StreamMiddleware), + _ => Err(Error::new( + Status::InvalidArg, + format!("Unsupported stream middleware: {name}"), + )), + }) + .collect() +} + +fn parse_protocol(protocol: &str) -> Result { + match protocol { + "openai_chat" | "openai-chat" | "openai_chat_completions" | "chat-completions" | "chat_completions" => { + Ok(BackendProtocol::OpenaiChatCompletions) + } + "openai_responses" | "openai-responses" | "responses" => Ok(BackendProtocol::OpenaiResponses), + "anthropic" | "anthropic_messages" | "anthropic-messages" => Ok(BackendProtocol::AnthropicMessages), + "gemini" | "gemini_generate_content" | "gemini-generate-content" => Ok(BackendProtocol::GeminiGenerateContent), + other => Err(Error::new( + Status::InvalidArg, + format!("Unsupported llm backend protocol: {other}"), + )), + } +} + +fn map_json_error(error: serde_json::Error) -> Error { + Error::new(Status::InvalidArg, format!("Invalid JSON payload: {error}")) +} + +fn map_backend_error(error: BackendError) -> Error { + Error::new(Status::GenericFailure, error.to_string()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn should_parse_supported_protocol_aliases() { + assert!(parse_protocol("openai_chat").is_ok()); + assert!(parse_protocol("chat-completions").is_ok()); + assert!(parse_protocol("responses").is_ok()); + assert!(parse_protocol("anthropic").is_ok()); + 
assert!(parse_protocol("gemini").is_ok()); + } + + #[test] + fn should_reject_unsupported_protocol() { + let error = parse_protocol("unknown").unwrap_err(); + assert_eq!(error.status, Status::InvalidArg); + assert!(error.reason.contains("Unsupported llm backend protocol")); + } + + #[test] + fn llm_dispatch_should_reject_invalid_backend_json() { + let error = llm_dispatch("openai_chat".to_string(), "{".to_string(), "{}".to_string()).unwrap_err(); + assert_eq!(error.status, Status::InvalidArg); + assert!(error.reason.contains("Invalid JSON payload")); + } + + #[test] + fn map_json_error_should_use_invalid_arg_status() { + let parse_error = serde_json::from_str::("{").unwrap_err(); + let error = map_json_error(parse_error); + assert_eq!(error.status, Status::InvalidArg); + assert!(error.reason.contains("Invalid JSON payload")); + } + + #[test] + fn resolve_request_chain_should_support_clamp_max_tokens() { + let chain = resolve_request_chain(&["normalize_messages".to_string(), "clamp_max_tokens".to_string()]).unwrap(); + assert_eq!(chain.len(), 2); + } + + #[test] + fn resolve_request_chain_should_reject_unknown_middleware() { + let error = resolve_request_chain(&["unknown".to_string()]).unwrap_err(); + assert_eq!(error.status, Status::InvalidArg); + assert!(error.reason.contains("Unsupported request middleware")); + } + + #[test] + fn resolve_stream_chain_should_reject_unknown_middleware() { + let error = resolve_stream_chain(&["unknown".to_string()]).unwrap_err(); + assert_eq!(error.status, Status::InvalidArg); + assert!(error.reason.contains("Unsupported stream middleware")); + } +} diff --git a/packages/backend/server/.env.example b/packages/backend/server/.env.example index 2d8a2362b7..6f26b561ea 100644 --- a/packages/backend/server/.env.example +++ b/packages/backend/server/.env.example @@ -6,6 +6,7 @@ # MAILER_HOST=127.0.0.1 # MAILER_PORT=1025 +# MAILER_SERVERNAME="mail.example.com" # MAILER_SENDER="noreply@toeverything.info" # 
MAILER_USER="noreply@toeverything.info" # MAILER_PASSWORD="affine" diff --git a/packages/backend/server/package.json b/packages/backend/server/package.json index 60f9343ae5..11e83e33a2 100644 --- a/packages/backend/server/package.json +++ b/packages/backend/server/package.json @@ -4,17 +4,14 @@ "version": "0.26.3", "description": "Affine Node.js server", "type": "module", - "bin": { - "run-test": "./scripts/run-test.ts" - }, "scripts": { "build": "affine bundle -p @affine/server", "dev": "nodemon ./src/index.ts", "dev:mail": "email dev -d src/mails", "test": "ava --concurrency 1 --serial", - "test:copilot": "ava \"src/__tests__/copilot-*.spec.ts\"", + "test:copilot": "ava \"src/__tests__/copilot/copilot-*.spec.ts\"", "test:coverage": "c8 ava --concurrency 1 --serial", - "test:copilot:coverage": "c8 ava --timeout=5m \"src/__tests__/copilot-*.spec.ts\"", + "test:copilot:coverage": "c8 ava --timeout=5m \"src/__tests__/copilot/copilot-*.spec.ts\"", "e2e": "cross-env TEST_MODE=e2e ava --serial", "e2e:coverage": "cross-env TEST_MODE=e2e c8 ava --serial", "data-migration": "cross-env NODE_ENV=development SERVER_FLAVOR=script r ./src/index.ts", @@ -28,19 +25,12 @@ "dependencies": { "@affine/s3-compat": "workspace:*", "@affine/server-native": "workspace:*", - "@ai-sdk/anthropic": "^2.0.54", - "@ai-sdk/google": "^2.0.45", - "@ai-sdk/google-vertex": "^3.0.88", - "@ai-sdk/openai": "^2.0.80", - "@ai-sdk/openai-compatible": "^1.0.28", - "@ai-sdk/perplexity": "^2.0.21", "@apollo/server": "^4.13.0", "@fal-ai/serverless-client": "^0.15.0", "@google-cloud/opentelemetry-cloud-trace-exporter": "^3.0.0", "@google-cloud/opentelemetry-resource-util": "^3.0.0", - "@modelcontextprotocol/sdk": "^1.26.0", - "@nestjs-cls/transactional": "^2.7.0", - "@nestjs-cls/transactional-adapter-prisma": "^1.2.24", + "@nestjs-cls/transactional": "^3.2.0", + "@nestjs-cls/transactional-adapter-prisma": "^1.3.4", "@nestjs/apollo": "^13.0.4", "@nestjs/bullmq": "^11.0.4", "@nestjs/common": "^11.0.21", @@ 
-55,18 +45,18 @@ "@node-rs/crc32": "^1.10.6", "@opentelemetry/api": "^1.9.0", "@opentelemetry/core": "^2.2.0", - "@opentelemetry/exporter-prometheus": "^0.211.0", + "@opentelemetry/exporter-prometheus": "^0.212.0", "@opentelemetry/exporter-zipkin": "^2.2.0", "@opentelemetry/host-metrics": "^0.38.0", - "@opentelemetry/instrumentation": "^0.211.0", - "@opentelemetry/instrumentation-graphql": "^0.58.0", - "@opentelemetry/instrumentation-http": "^0.211.0", - "@opentelemetry/instrumentation-ioredis": "^0.59.0", - "@opentelemetry/instrumentation-nestjs-core": "^0.57.0", - "@opentelemetry/instrumentation-socket.io": "^0.57.0", + "@opentelemetry/instrumentation": "^0.212.0", + "@opentelemetry/instrumentation-graphql": "^0.60.0", + "@opentelemetry/instrumentation-http": "^0.212.0", + "@opentelemetry/instrumentation-ioredis": "^0.60.0", + "@opentelemetry/instrumentation-nestjs-core": "^0.58.0", + "@opentelemetry/instrumentation-socket.io": "^0.59.0", "@opentelemetry/resources": "^2.2.0", "@opentelemetry/sdk-metrics": "^2.2.0", - "@opentelemetry/sdk-node": "^0.211.0", + "@opentelemetry/sdk-node": "^0.212.0", "@opentelemetry/sdk-trace-node": "^2.2.0", "@opentelemetry/semantic-conventions": "^1.38.0", "@prisma/client": "^6.6.0", @@ -74,7 +64,6 @@ "@queuedash/api": "^3.16.0", "@react-email/components": "^0.5.7", "@socket.io/redis-adapter": "^8.3.0", - "ai": "^5.0.118", "bullmq": "^5.40.2", "cookie-parser": "^1.4.7", "cross-env": "^10.1.0", @@ -126,7 +115,6 @@ "@faker-js/faker": "^10.1.0", "@nestjs/swagger": "^11.2.0", "@nestjs/testing": "patch:@nestjs/testing@npm%3A10.4.15#~/.yarn/patches/@nestjs-testing-npm-10.4.15-d591a1705a.patch", - "@react-email/preview-server": "^4.3.2", "@types/cookie-parser": "^1.4.8", "@types/express": "^5.0.1", "@types/express-serve-static-core": "^5.0.6", @@ -142,8 +130,8 @@ "@types/react-dom": "^19.0.2", "@types/semver": "^7.5.8", "@types/sinon": "^21.0.0", - "@types/supertest": "^6.0.2", - "ava": "^6.4.0", + "@types/supertest": "^7.0.0", + "ava": 
"^7.0.0", "c8": "^10.1.3", "nodemon": "^3.1.14", "react-email": "^4.3.2", diff --git a/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.snap b/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.snap deleted file mode 100644 index 9239165a93..0000000000 Binary files a/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.snap and /dev/null differ diff --git a/packages/backend/server/src/__tests__/__snapshots__/worker.e2e.ts.md b/packages/backend/server/src/__tests__/__snapshots__/worker.e2e.ts.md index d59ba4a7e3..ffd25a2b5c 100644 --- a/packages/backend/server/src/__tests__/__snapshots__/worker.e2e.ts.md +++ b/packages/backend/server/src/__tests__/__snapshots__/worker.e2e.ts.md @@ -43,7 +43,9 @@ Generated by [AVA](https://avajs.dev). > Snapshot 5 Buffer @Uint8Array [ - 66616b65 20696d61 6765 + 89504e47 0d0a1a0a 0000000d 49484452 00000001 00000001 08040000 00b51c0c + 02000000 0b494441 5478da63 fcff1f00 03030200 efa37c9f 00000000 49454e44 + ae426082 ] ## should preview link diff --git a/packages/backend/server/src/__tests__/__snapshots__/worker.e2e.ts.snap b/packages/backend/server/src/__tests__/__snapshots__/worker.e2e.ts.snap index 4679da1148..eefab65bc3 100644 Binary files a/packages/backend/server/src/__tests__/__snapshots__/worker.e2e.ts.snap and b/packages/backend/server/src/__tests__/__snapshots__/worker.e2e.ts.snap differ diff --git a/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.md b/packages/backend/server/src/__tests__/copilot/__snapshots__/copilot.e2e.ts.md similarity index 95% rename from packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.md rename to packages/backend/server/src/__tests__/copilot/__snapshots__/copilot.e2e.ts.md index 49f5c10e76..2bf58ee451 100644 --- a/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.md +++ b/packages/backend/server/src/__tests__/copilot/__snapshots__/copilot.e2e.ts.md @@ -12,12 +12,12 @@ Generated by 
[AVA](https://avajs.dev). { messages: [ { - content: 'generate text to text', + content: 'generate text to text stream', role: 'assistant', }, ], pinned: false, - tokens: 8, + tokens: 10, }, ] @@ -27,12 +27,12 @@ Generated by [AVA](https://avajs.dev). { messages: [ { - content: 'generate text to text', + content: 'generate text to text stream', role: 'assistant', }, ], pinned: false, - tokens: 8, + tokens: 10, }, ] diff --git a/packages/backend/server/src/__tests__/copilot/__snapshots__/copilot.e2e.ts.snap b/packages/backend/server/src/__tests__/copilot/__snapshots__/copilot.e2e.ts.snap new file mode 100644 index 0000000000..59261fcfc6 Binary files /dev/null and b/packages/backend/server/src/__tests__/copilot/__snapshots__/copilot.e2e.ts.snap differ diff --git a/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.md b/packages/backend/server/src/__tests__/copilot/__snapshots__/copilot.spec.ts.md similarity index 100% rename from packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.md rename to packages/backend/server/src/__tests__/copilot/__snapshots__/copilot.spec.ts.md diff --git a/packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.snap b/packages/backend/server/src/__tests__/copilot/__snapshots__/copilot.spec.ts.snap similarity index 100% rename from packages/backend/server/src/__tests__/__snapshots__/copilot.spec.ts.snap rename to packages/backend/server/src/__tests__/copilot/__snapshots__/copilot.spec.ts.snap diff --git a/packages/backend/server/src/__tests__/copilot-provider.spec.ts b/packages/backend/server/src/__tests__/copilot/copilot-provider.spec.ts similarity index 90% rename from packages/backend/server/src/__tests__/copilot-provider.spec.ts rename to packages/backend/server/src/__tests__/copilot/copilot-provider.spec.ts index 80874c8b41..903932683e 100644 --- a/packages/backend/server/src/__tests__/copilot-provider.spec.ts +++ b/packages/backend/server/src/__tests__/copilot/copilot-provider.spec.ts @@ 
-4,31 +4,31 @@ import type { ExecutionContext, TestFn } from 'ava'; import ava from 'ava'; import { z } from 'zod'; -import { ServerFeature, ServerService } from '../core'; -import { AuthService } from '../core/auth'; -import { QuotaModule } from '../core/quota'; -import { Models } from '../models'; -import { CopilotModule } from '../plugins/copilot'; -import { prompts, PromptService } from '../plugins/copilot/prompt'; +import { ServerFeature, ServerService } from '../../core'; +import { AuthService } from '../../core/auth'; +import { QuotaModule } from '../../core/quota'; +import { Models } from '../../models'; +import { CopilotModule } from '../../plugins/copilot'; +import { prompts, PromptService } from '../../plugins/copilot/prompt'; import { CopilotProviderFactory, CopilotProviderType, StreamObject, StreamObjectSchema, -} from '../plugins/copilot/providers'; -import { TranscriptionResponseSchema } from '../plugins/copilot/transcript/types'; +} from '../../plugins/copilot/providers'; +import { TranscriptionResponseSchema } from '../../plugins/copilot/transcript/types'; import { CopilotChatTextExecutor, CopilotWorkflowService, GraphExecutorState, -} from '../plugins/copilot/workflow'; +} from '../../plugins/copilot/workflow'; import { CopilotChatImageExecutor, CopilotCheckHtmlExecutor, CopilotCheckJsonExecutor, -} from '../plugins/copilot/workflow/executor'; -import { createTestingModule, TestingModule } from './utils'; -import { TestAssets } from './utils/copilot'; +} from '../../plugins/copilot/workflow/executor'; +import { createTestingModule, TestingModule } from '../utils'; +import { TestAssets } from '../utils/copilot'; type Tester = { auth: AuthService; @@ -118,7 +118,6 @@ test.serial.before(async t => { enabled: true, scenarios: { image: 'flux-1/schnell', - rerank: 'gpt-5-mini', complex_text_generation: 'gpt-5-mini', coding: 'gpt-5-mini', quick_decision_making: 'gpt-5-mini', @@ -226,6 +225,20 @@ const checkStreamObjects = (result: string) => { } }; 
+const parseStreamObjects = (result: string): StreamObject[] => { + const streamObjects = JSON.parse(result); + return z.array(StreamObjectSchema).parse(streamObjects); +}; + +const getStreamObjectText = (result: string) => + parseStreamObjects(result) + .filter( + (chunk): chunk is Extract => + chunk.type === 'text-delta' + ) + .map(chunk => chunk.textDelta) + .join(''); + const retry = async ( action: string, t: ExecutionContext, @@ -445,6 +458,49 @@ The term **“CRDT”** was first introduced by Marc Shapiro, Nuno Preguiça, Ca }, type: 'object' as const, }, + { + name: 'Gemini native text', + promptName: ['Chat With AFFiNE AI'], + messages: [ + { + role: 'user' as const, + content: + 'In one short sentence, explain what AFFiNE AI is and mention AFFiNE by name.', + }, + ], + config: { model: 'gemini-2.5-flash' }, + verifier: (t: ExecutionContext, result: string) => { + assertNotWrappedInCodeBlock(t, result); + t.assert( + result.toLowerCase().includes('affine'), + 'should mention AFFiNE' + ); + }, + prefer: CopilotProviderType.Gemini, + type: 'text' as const, + }, + { + name: 'Gemini native stream objects', + promptName: ['Chat With AFFiNE AI'], + messages: [ + { + role: 'user' as const, + content: + 'Respond with one short sentence about AFFiNE AI and mention AFFiNE by name.', + }, + ], + config: { model: 'gemini-2.5-flash' }, + verifier: (t: ExecutionContext, result: string) => { + t.truthy(checkStreamObjects(result), 'should be valid stream objects'); + const assembledText = getStreamObjectText(result); + t.assert( + assembledText.toLowerCase().includes('affine'), + 'should mention AFFiNE' + ); + }, + prefer: CopilotProviderType.Gemini, + type: 'object' as const, + }, { name: 'Should transcribe short audio', promptName: ['Transcript audio'], @@ -717,14 +773,13 @@ for (const { const { factory, prompt: promptService } = t.context; const prompt = (await promptService.get(promptName))!; t.truthy(prompt, 'should have prompt'); - const provider = (await 
factory.getProviderByModel(prompt.model, { + const finalConfig = Object.assign({}, prompt.config, config); + const modelId = finalConfig.model || prompt.model; + const provider = (await factory.getProviderByModel(modelId, { prefer, }))!; t.truthy(provider, 'should have provider'); await retry(`action: ${promptName}`, t, async t => { - const finalConfig = Object.assign({}, prompt.config, config); - const modelId = finalConfig.model || prompt.model; - switch (type) { case 'text': { const result = await provider.text( @@ -892,7 +947,7 @@ test( 'should be able to rerank message chunks', runIfCopilotConfigured, async t => { - const { factory, prompt } = t.context; + const { factory } = t.context; await retry('rerank', t, async t => { const query = 'Is this content relevant to programming?'; @@ -909,14 +964,18 @@ test( 'The stock market is experiencing significant fluctuations.', ]; - const p = (await prompt.get('Rerank results'))!; - t.assert(p, 'should have prompt for rerank'); - const provider = (await factory.getProviderByModel(p.model))!; + const provider = (await factory.getProviderByModel('gpt-5.2'))!; t.assert(provider, 'should have provider for rerank'); const scores = await provider.rerank( - { modelId: p.model }, - embeddings.map(e => p.finish({ query, doc: e })) + { modelId: 'gpt-5.2' }, + { + query, + candidates: embeddings.map((text, index) => ({ + id: String(index), + text, + })), + } ); t.is(scores.length, 10, 'should return scores for all chunks'); @@ -931,8 +990,8 @@ test( t.log('Rerank scores:', scores); t.is( scores.filter(s => s > 0.5).length, - 4, - 'should have 4 related chunks' + 5, + 'should have 5 related chunks' ); }); } diff --git a/packages/backend/server/src/__tests__/copilot.e2e.ts b/packages/backend/server/src/__tests__/copilot/copilot.e2e.ts similarity index 95% rename from packages/backend/server/src/__tests__/copilot.e2e.ts rename to packages/backend/server/src/__tests__/copilot/copilot.e2e.ts index 3a8bf5aee3..82a1316a3a 100644 --- 
a/packages/backend/server/src/__tests__/copilot.e2e.ts +++ b/packages/backend/server/src/__tests__/copilot/copilot.e2e.ts @@ -6,25 +6,26 @@ import type { TestFn } from 'ava'; import ava from 'ava'; import Sinon from 'sinon'; -import { AppModule } from '../app.module'; -import { JobQueue } from '../base'; -import { ConfigModule } from '../base/config'; -import { AuthService } from '../core/auth'; -import { DocReader } from '../core/doc'; -import { CopilotContextService } from '../plugins/copilot/context'; +import { AppModule } from '../../app.module'; +import { JobQueue } from '../../base'; +import { ConfigModule } from '../../base/config'; +import { AuthService } from '../../core/auth'; +import { DocReader } from '../../core/doc'; +import { CopilotContextService } from '../../plugins/copilot/context'; import { CopilotEmbeddingJob, MockEmbeddingClient, -} from '../plugins/copilot/embedding'; -import { prompts, PromptService } from '../plugins/copilot/prompt'; +} from '../../plugins/copilot/embedding'; +import { ChatMessageCache } from '../../plugins/copilot/message'; +import { prompts, PromptService } from '../../plugins/copilot/prompt'; import { CopilotProviderFactory, CopilotProviderType, GeminiGenerativeProvider, OpenAIProvider, -} from '../plugins/copilot/providers'; -import { CopilotStorage } from '../plugins/copilot/storage'; -import { MockCopilotProvider } from './mocks'; +} from '../../plugins/copilot/providers'; +import { CopilotStorage } from '../../plugins/copilot/storage'; +import { MockCopilotProvider } from '../mocks'; import { acceptInviteById, createTestingApp, @@ -33,7 +34,7 @@ import { smallestPng, TestingApp, TestUser, -} from './utils'; +} from '../utils'; import { addContextDoc, addContextFile, @@ -67,7 +68,7 @@ import { textToEventStream, unsplashSearch, updateCopilotSession, -} from './utils/copilot'; +} from '../utils/copilot'; const test = ava as TestFn<{ auth: AuthService; @@ -416,6 +417,7 @@ test('should be able to use test provider', 
async t => { test('should create message correctly', async t => { const { app } = t.context; + const messageCache = app.get(ChatMessageCache); { const { id } = await createWorkspace(app); @@ -463,6 +465,19 @@ test('should create message correctly', async t => { new File([new Uint8Array(pngData)], '1.png', { type: 'image/png' }) ); t.truthy(messageId, 'should be able to create message with blob'); + + const message = await messageCache.get(messageId); + const attachment = message?.attachments?.[0] as + | { attachment: string; mimeType: string } + | undefined; + const payload = Buffer.from( + attachment?.attachment.split(',').at(1) || '', + 'base64' + ); + + t.is(attachment?.mimeType, 'image/webp'); + t.is(payload.subarray(0, 4).toString('ascii'), 'RIFF'); + t.is(payload.subarray(8, 12).toString('ascii'), 'WEBP'); } // with attachments @@ -513,7 +528,11 @@ test('should be able to chat with api', async t => { ); const messageId = await createCopilotMessage(app, sessionId); const ret = await chatWithText(app, sessionId, messageId); - t.is(ret, 'generate text to text', 'should be able to chat with text'); + t.is( + ret, + 'generate text to text stream', + 'should be able to chat with text' + ); const ret2 = await chatWithTextStream(app, sessionId, messageId); t.is( @@ -657,7 +676,7 @@ test('should be able to retry with api', async t => { const histories = await getHistories(app, { workspaceId: id, docId }); t.deepEqual( histories.map(h => h.messages.map(m => m.content)), - [['generate text to text', 'generate text to text']], + [['generate text to text stream', 'generate text to text stream']], 'should be able to list history' ); } @@ -794,7 +813,7 @@ test('should be able to list history', async t => { const histories = await getHistories(app, { workspaceId, docId }); t.deepEqual( histories.map(h => h.messages.map(m => m.content)), - [['hello', 'generate text to text']], + [['hello', 'generate text to text stream']], 'should be able to list history' ); } @@ -807,7 
+826,7 @@ test('should be able to list history', async t => { }); t.deepEqual( histories.map(h => h.messages.map(m => m.content)), - [['generate text to text', 'hello']], + [['generate text to text stream', 'hello']], 'should be able to list history' ); } @@ -858,7 +877,7 @@ test('should reject request that user have not permission', async t => { const histories = await getHistories(app, { workspaceId, docId }); t.deepEqual( histories.map(h => h.messages.map(m => m.content)), - [['generate text to text']], + [['generate text to text stream']], 'should able to list history' ); diff --git a/packages/backend/server/src/__tests__/copilot.spec.ts b/packages/backend/server/src/__tests__/copilot/copilot.spec.ts similarity index 89% rename from packages/backend/server/src/__tests__/copilot.spec.ts rename to packages/backend/server/src/__tests__/copilot/copilot.spec.ts index 88afae12e0..6b6ef9d22e 100644 --- a/packages/backend/server/src/__tests__/copilot.spec.ts +++ b/packages/backend/server/src/__tests__/copilot/copilot.spec.ts @@ -8,38 +8,35 @@ import ava from 'ava'; import { nanoid } from 'nanoid'; import Sinon from 'sinon'; -import { EventBus, JobQueue } from '../base'; -import { ConfigModule } from '../base/config'; -import { AuthService } from '../core/auth'; -import { QuotaModule } from '../core/quota'; -import { StorageModule, WorkspaceBlobStorage } from '../core/storage'; +import { EventBus, JobQueue } from '../../base'; +import { ConfigModule } from '../../base/config'; +import { AuthService } from '../../core/auth'; +import { QuotaModule } from '../../core/quota'; +import { StorageModule, WorkspaceBlobStorage } from '../../core/storage'; import { ContextCategories, CopilotSessionModel, WorkspaceModel, -} from '../models'; -import { CopilotModule } from '../plugins/copilot'; -import { CopilotContextService } from '../plugins/copilot/context'; -import { CopilotCronJobs } from '../plugins/copilot/cron'; +} from '../../models'; +import { CopilotModule } from 
'../../plugins/copilot'; +import { CopilotContextService } from '../../plugins/copilot/context'; +import { CopilotCronJobs } from '../../plugins/copilot/cron'; import { CopilotEmbeddingJob, MockEmbeddingClient, -} from '../plugins/copilot/embedding'; -import { prompts, PromptService } from '../plugins/copilot/prompt'; +} from '../../plugins/copilot/embedding'; +import { prompts, PromptService } from '../../plugins/copilot/prompt'; import { CopilotProviderFactory, CopilotProviderType, ModelInputType, ModelOutputType, OpenAIProvider, -} from '../plugins/copilot/providers'; -import { - CitationParser, - TextStreamParser, -} from '../plugins/copilot/providers/utils'; -import { ChatSessionService } from '../plugins/copilot/session'; -import { CopilotStorage } from '../plugins/copilot/storage'; -import { CopilotTranscriptionService } from '../plugins/copilot/transcript'; +} from '../../plugins/copilot/providers'; +import { TextStreamParser } from '../../plugins/copilot/providers/utils'; +import { ChatSessionService } from '../../plugins/copilot/session'; +import { CopilotStorage } from '../../plugins/copilot/storage'; +import { CopilotTranscriptionService } from '../../plugins/copilot/transcript'; import { CopilotChatTextExecutor, CopilotWorkflowService, @@ -48,7 +45,7 @@ import { WorkflowGraphExecutor, type WorkflowNodeData, WorkflowNodeType, -} from '../plugins/copilot/workflow'; +} from '../../plugins/copilot/workflow'; import { CopilotChatImageExecutor, CopilotCheckHtmlExecutor, @@ -56,16 +53,16 @@ import { getWorkflowExecutor, NodeExecuteState, NodeExecutorType, -} from '../plugins/copilot/workflow/executor'; -import { AutoRegisteredWorkflowExecutor } from '../plugins/copilot/workflow/executor/utils'; -import { WorkflowGraphList } from '../plugins/copilot/workflow/graph'; -import { CopilotWorkspaceService } from '../plugins/copilot/workspace'; -import { PaymentModule } from '../plugins/payment'; -import { SubscriptionService } from '../plugins/payment/service'; 
-import { SubscriptionStatus } from '../plugins/payment/types'; -import { MockCopilotProvider } from './mocks'; -import { createTestingModule, TestingModule } from './utils'; -import { WorkflowTestCases } from './utils/copilot'; +} from '../../plugins/copilot/workflow/executor'; +import { AutoRegisteredWorkflowExecutor } from '../../plugins/copilot/workflow/executor/utils'; +import { WorkflowGraphList } from '../../plugins/copilot/workflow/graph'; +import { CopilotWorkspaceService } from '../../plugins/copilot/workspace'; +import { PaymentModule } from '../../plugins/payment'; +import { SubscriptionService } from '../../plugins/payment/service'; +import { SubscriptionStatus } from '../../plugins/payment/types'; +import { MockCopilotProvider } from '../mocks'; +import { createTestingModule, TestingModule } from '../utils'; +import { WorkflowTestCases } from '../utils/copilot'; type Context = { auth: AuthService; @@ -364,6 +361,21 @@ test('should be able to manage chat session', async t => { }); t.is(newSessionId, sessionId, 'should get same session id'); } + + // should create a fresh session when reuseLatestChat is explicitly disabled + { + const newSessionId = await session.create({ + userId, + promptName, + ...commonParams, + reuseLatestChat: false, + }); + t.not( + newSessionId, + sessionId, + 'should create new session id when reuseLatestChat is false' + ); + } }); test('should be able to update chat session prompt', async t => { @@ -645,6 +657,55 @@ test('should be able to generate with message id', async t => { } }); +test('should preserve file handle attachments when merging user content into prompt', async t => { + const { prompt, session } = t.context; + + await prompt.set(promptName, 'model', [ + { role: 'user', content: '{{content}}' }, + ]); + + const sessionId = await session.create({ + docId: 'test', + workspaceId: 'test', + userId, + promptName, + pinned: false, + }); + const s = (await session.get(sessionId))!; + + const message = await 
session.createMessage({ + sessionId, + content: 'Summarize this file', + attachments: [ + { + kind: 'file_handle', + fileHandle: 'file_123', + mimeType: 'application/pdf', + }, + ], + }); + + await s.pushByMessageId(message); + const finalMessages = s.finish({}); + + t.deepEqual(finalMessages, [ + { + role: 'user', + content: 'Summarize this file', + attachments: [ + { + kind: 'file_handle', + fileHandle: 'file_123', + mimeType: 'application/pdf', + }, + ], + params: { + content: 'Summarize this file', + }, + }, + ]); +}); + test('should save message correctly', async t => { const { prompt, session } = t.context; @@ -881,6 +942,26 @@ test('should be able to get provider', async t => { } }); +test('should resolve provider by prefixed model id', async t => { + const { factory } = t.context; + + const provider = await factory.getProviderByModel('openai-default/test'); + t.truthy(provider, 'should resolve prefixed model id'); + t.is(provider?.type, CopilotProviderType.OpenAI); + + const result = await provider?.text({ modelId: 'openai-default/test' }, [ + { role: 'user', content: 'hello' }, + ]); + t.is(result, 'generate text to text'); +}); + +test('should fallback to null when prefixed provider id does not exist', async t => { + const { factory } = t.context; + + const provider = await factory.getProviderByModel('unknown/test'); + t.is(provider, null); +}); + // ==================== workflow ==================== // this test used to preview the final result of the workflow @@ -1190,149 +1271,6 @@ test('should be able to run image executor', async t => { Sinon.restore(); }); -test('CitationParser should replace citation placeholders with URLs', t => { - const content = - 'This is [a] test sentence with [citations [1]] and [[2]] and [3].'; - const citations = ['https://example1.com', 'https://example2.com']; - - const parser = new CitationParser(); - for (const citation of citations) { - parser.push(citation); - } - - const result = parser.parse(content) + 
parser.end(); - - const expected = [ - 'This is [a] test sentence with [citations [^1]] and [^2] and [3].', - `[^1]: {"type":"url","url":"${encodeURIComponent(citations[0])}"}`, - `[^2]: {"type":"url","url":"${encodeURIComponent(citations[1])}"}`, - ].join('\n'); - - t.is(result, expected); -}); - -test('CitationParser should replace chunks of citation placeholders with URLs', t => { - const contents = [ - '[[]]', - 'This is [', - 'a] test sentence ', - 'with citations [1', - '] and [', - '[2]] and [[', - '3]] and [[4', - ']] and [[5]', - '] and [[6]]', - ' and [7', - ]; - const citations = [ - 'https://example1.com', - 'https://example2.com', - 'https://example3.com', - 'https://example4.com', - 'https://example5.com', - 'https://example6.com', - 'https://example7.com', - ]; - - const parser = new CitationParser(); - for (const citation of citations) { - parser.push(citation); - } - - let result = contents.reduce((acc, current) => { - return acc + parser.parse(current); - }, ''); - result += parser.end(); - - const expected = [ - '[[]]This is [a] test sentence with citations [^1] and [^2] and [^3] and [^4] and [^5] and [^6] and [7', - `[^1]: {"type":"url","url":"${encodeURIComponent(citations[0])}"}`, - `[^2]: {"type":"url","url":"${encodeURIComponent(citations[1])}"}`, - `[^3]: {"type":"url","url":"${encodeURIComponent(citations[2])}"}`, - `[^4]: {"type":"url","url":"${encodeURIComponent(citations[3])}"}`, - `[^5]: {"type":"url","url":"${encodeURIComponent(citations[4])}"}`, - `[^6]: {"type":"url","url":"${encodeURIComponent(citations[5])}"}`, - `[^7]: {"type":"url","url":"${encodeURIComponent(citations[6])}"}`, - ].join('\n'); - t.is(result, expected); -}); - -test('CitationParser should not replace citation already with URLs', t => { - const content = - 'This is [a] test sentence with citations [1](https://example1.com) and [[2]](https://example2.com) and [[3](https://example3.com)].'; - const citations = [ - 'https://example4.com', - 'https://example5.com', - 
'https://example6.com', - ]; - - const parser = new CitationParser(); - for (const citation of citations) { - parser.push(citation); - } - - const result = parser.parse(content) + parser.end(); - - const expected = [ - content, - `[^1]: {"type":"url","url":"${encodeURIComponent(citations[0])}"}`, - `[^2]: {"type":"url","url":"${encodeURIComponent(citations[1])}"}`, - `[^3]: {"type":"url","url":"${encodeURIComponent(citations[2])}"}`, - ].join('\n'); - t.is(result, expected); -}); - -test('CitationParser should not replace chunks of citation already with URLs', t => { - const contents = [ - 'This is [a] test sentence with citations [1', - '](https://example1.com) and [[2]', - '](https://example2.com) and [[3](https://example3.com)].', - ]; - const citations = [ - 'https://example4.com', - 'https://example5.com', - 'https://example6.com', - ]; - - const parser = new CitationParser(); - for (const citation of citations) { - parser.push(citation); - } - - let result = contents.reduce((acc, current) => { - return acc + parser.parse(current); - }, ''); - result += parser.end(); - - const expected = [ - contents.join(''), - `[^1]: {"type":"url","url":"${encodeURIComponent(citations[0])}"}`, - `[^2]: {"type":"url","url":"${encodeURIComponent(citations[1])}"}`, - `[^3]: {"type":"url","url":"${encodeURIComponent(citations[2])}"}`, - ].join('\n'); - t.is(result, expected); -}); - -test('CitationParser should replace openai style reference chunks', t => { - const contents = [ - 'This is [a] test sentence with citations ', - '([example1.com](https://example1.com))', - ]; - - const parser = new CitationParser(); - - let result = contents.reduce((acc, current) => { - return acc + parser.parse(current); - }, ''); - result += parser.end(); - - const expected = [ - contents[0] + '[^1]', - `[^1]: {"type":"url","url":"${encodeURIComponent('https://example1.com')}"}`, - ].join('\n'); - t.is(result, expected); -}); - test('TextStreamParser should format different types of chunks 
correctly', t => { // Define interfaces for fixtures interface BaseFixture { @@ -2063,25 +2001,23 @@ test('should handle copilot cron jobs correctly', async t => { }); test('should resolve model correctly based on subscription status and prompt config', async t => { - const { db, session, subscription } = t.context; + const { prompt, session, subscription } = t.context; // 1) Seed a prompt that has optionalModels and proModels in config const promptName = 'resolve-model-test'; - await db.aiPrompt.create({ - data: { - name: promptName, - model: 'gemini-2.5-flash', - messages: { - create: [{ idx: 0, role: 'system', content: 'test' }], - }, - config: { proModels: ['gemini-2.5-pro', 'claude-sonnet-4-5@20250929'] }, + await prompt.set( + promptName, + 'gemini-2.5-flash', + [{ role: 'system', content: 'test' }], + { proModels: ['gemini-2.5-pro', 'claude-sonnet-4-5@20250929'] }, + { optionalModels: [ 'gemini-2.5-flash', 'gemini-2.5-pro', 'claude-sonnet-4-5@20250929', ], - }, - }); + } + ); // 2) Create a chat session with this prompt const sessionId = await session.create({ @@ -2106,6 +2042,16 @@ test('should resolve model correctly based on subscription status and prompt con const model1 = await s.resolveModel(false, 'gemini-2.5-pro'); t.snapshot(model1, 'should honor requested pro model'); + const model1WithPrefix = await s.resolveModel( + false, + 'openai-default/gemini-2.5-pro' + ); + t.is( + model1WithPrefix, + 'openai-default/gemini-2.5-pro', + 'should honor requested prefixed pro model' + ); + const model2 = await s.resolveModel(false, 'not-in-optional'); t.snapshot(model2, 'should fallback to default model'); } @@ -2119,6 +2065,16 @@ test('should resolve model correctly based on subscription status and prompt con 'should fallback to default model when requesting pro model during trialing' ); + const model3WithPrefix = await s.resolveModel( + true, + 'openai-default/gemini-2.5-pro' + ); + t.is( + model3WithPrefix, + 'gemini-2.5-flash', + 'should fallback to default 
model when requesting prefixed pro model during trialing' + ); + const model4 = await s.resolveModel(true, 'gemini-2.5-flash'); t.snapshot(model4, 'should honor requested non-pro model during trialing'); @@ -2141,6 +2097,16 @@ test('should resolve model correctly based on subscription status and prompt con const model7 = await s.resolveModel(true, 'claude-sonnet-4-5@20250929'); t.snapshot(model7, 'should honor requested pro model during active'); + const model7WithPrefix = await s.resolveModel( + true, + 'openai-default/claude-sonnet-4-5@20250929' + ); + t.is( + model7WithPrefix, + 'openai-default/claude-sonnet-4-5@20250929', + 'should honor requested prefixed pro model during active' + ); + const model8 = await s.resolveModel(true, 'not-in-optional'); t.snapshot( model8, diff --git a/packages/backend/server/src/__tests__/copilot/native-provider.spec.ts b/packages/backend/server/src/__tests__/copilot/native-provider.spec.ts new file mode 100644 index 0000000000..1f66c5f2d8 --- /dev/null +++ b/packages/backend/server/src/__tests__/copilot/native-provider.spec.ts @@ -0,0 +1,1431 @@ +import test from 'ava'; +import { z } from 'zod'; + +import { CopilotPromptInvalid, CopilotProviderSideError } from '../../base'; +import type { + NativeLlmBackendConfig, + NativeLlmEmbeddingRequest, + NativeLlmEmbeddingResponse, + NativeLlmRequest, + NativeLlmRerankRequest, + NativeLlmRerankResponse, + NativeLlmStreamEvent, + NativeLlmStructuredRequest, + NativeLlmStructuredResponse, +} from '../../native'; +import { ProviderMiddlewareConfig } from '../../plugins/copilot/config'; +import { GeminiProvider } from '../../plugins/copilot/providers/gemini/gemini'; +import { GeminiVertexProvider } from '../../plugins/copilot/providers/gemini/vertex'; +import { + buildNativeRequest, + NativeProviderAdapter, +} from '../../plugins/copilot/providers/native'; +import { OpenAIProvider } from '../../plugins/copilot/providers/openai'; +import { PerplexityProvider } from 
'../../plugins/copilot/providers/perplexity'; +import { + CopilotProviderType, + ModelInputType, + ModelOutputType, + type PromptMessage, +} from '../../plugins/copilot/providers/types'; +import type { CopilotToolSet } from '../../plugins/copilot/tools'; + +const mockDispatch = () => + (async function* (): AsyncIterableIterator { + yield { type: 'text_delta', text: 'Use [^1] now' }; + yield { type: 'citation', index: 1, url: 'https://affine.pro' }; + yield { type: 'done', finish_reason: 'stop' }; + })(); + +function stream( + factory: () => NativeLlmStreamEvent[] +): AsyncIterableIterator { + return (async function* () { + for (const event of factory()) { + yield event; + } + })(); +} + +class TestGeminiProvider extends GeminiProvider<{ apiKey: string }> { + override readonly type = CopilotProviderType.Gemini; + override readonly models = [ + { + id: 'gemini-2.5-flash', + capabilities: [ + { + input: [ + ModelInputType.Text, + ModelInputType.Image, + ModelInputType.Audio, + ModelInputType.File, + ], + output: [ + ModelOutputType.Text, + ModelOutputType.Object, + ModelOutputType.Structured, + ], + }, + ], + }, + { + id: 'gemini-embedding-001', + capabilities: [ + { + input: [ModelInputType.Text], + output: [ModelOutputType.Embedding], + }, + ], + }, + ]; + readonly dispatchRequests: NativeLlmRequest[] = []; + readonly structuredRequests: NativeLlmStructuredRequest[] = []; + readonly embeddingRequests: NativeLlmEmbeddingRequest[] = []; + readonly remoteAttachmentRequests: string[] = []; + readonly remoteAttachmentSignals: Array = []; + readonly retryDelays: number[] = []; + remoteAttachmentResponses = new Map< + string, + { data: string; mimeType: string } + >(); + testTools: CopilotToolSet = {}; + testMiddleware: ProviderMiddlewareConfig = { + rust: { + request: ['normalize_messages', 'tool_schema_rewrite'], + stream: ['stream_event_normalize', 'citation_indexing'], + }, + node: { + text: ['citation_footnote', 'callout'], + }, + }; + dispatchFactory: (request: 
NativeLlmRequest) => NativeLlmStreamEvent[] = + () => [ + { type: 'text_delta', text: 'native' }, + { type: 'done', finish_reason: 'stop' }, + ]; + structuredFactory: ( + request: NativeLlmStructuredRequest + ) => NativeLlmStructuredResponse = () => ({ + id: 'structured_1', + model: 'gemini-2.5-flash', + output_text: '{"summary":"AFFiNE native"}', + usage: { + prompt_tokens: 4, + completion_tokens: 3, + total_tokens: 7, + }, + finish_reason: 'stop', + }); + embeddingFactory: ( + request: NativeLlmEmbeddingRequest + ) => NativeLlmEmbeddingResponse = request => ({ + model: request.model, + embeddings: request.inputs.map((_, index) => [index + 0.1, index + 0.2]), + usage: { + prompt_tokens: request.inputs.length, + total_tokens: request.inputs.length, + }, + }); + + override configured() { + return true; + } + + protected override async createNativeConfig(): Promise { + return { + base_url: 'https://generativelanguage.googleapis.com/v1beta', + auth_token: 'api-key', + request_layer: 'gemini_api', + }; + } + + protected override createNativeDispatch( + _backendConfig: NativeLlmBackendConfig + ) { + return (request: NativeLlmRequest) => { + this.dispatchRequests.push(request); + return stream(() => this.dispatchFactory(request)); + }; + } + + protected override createNativeStructuredDispatch( + _backendConfig: NativeLlmBackendConfig + ) { + return async (request: NativeLlmStructuredRequest) => { + this.structuredRequests.push(request); + return this.structuredFactory(request); + }; + } + + protected override createNativeEmbeddingDispatch( + _backendConfig: NativeLlmBackendConfig + ) { + return async (request: NativeLlmEmbeddingRequest) => { + this.embeddingRequests.push(request); + return this.embeddingFactory(request); + }; + } + + protected override async fetchRemoteAttach( + url: string, + signal?: AbortSignal + ) { + this.remoteAttachmentRequests.push(url); + this.remoteAttachmentSignals.push(signal); + const response = this.remoteAttachmentResponses.get(url); + if 
(!response) { + throw new Error(`missing remote attachment stub for ${url}`); + } + return response; + } + + protected override async waitForStructuredRetry(delayMs: number) { + this.retryDelays.push(delayMs); + } + + protected override getActiveProviderMiddleware(): ProviderMiddlewareConfig { + return this.testMiddleware; + } + + protected override async getTools(): Promise { + return this.testTools; + } +} + +class TestGeminiVertexProvider extends GeminiVertexProvider { + testConfig = { + location: 'us-central1', + project: 'p1', + googleAuthOptions: {}, + } as any; + readonly dispatchRequests: NativeLlmRequest[] = []; + readonly remoteAttachmentRequests: string[] = []; + readonly remoteAttachmentSignals: Array = []; + remoteAttachmentResponses = new Map< + string, + { data: string; mimeType: string } + >(); + testTools: CopilotToolSet = {}; + testMiddleware: ProviderMiddlewareConfig = { + rust: { + request: ['normalize_messages', 'tool_schema_rewrite'], + stream: ['stream_event_normalize', 'citation_indexing'], + }, + node: { + text: ['citation_footnote', 'callout'], + }, + }; + + override get config() { + return this.testConfig; + } + + override configured() { + return true; + } + + protected override async resolveVertexAuth() { + return { + baseUrl: 'https://vertex.example', + headers: () => ({ + Authorization: 'Bearer vertex-token', + 'x-goog-user-project': 'p1', + }), + fetch: undefined, + } as const; + } + + protected override createNativeDispatch( + _backendConfig: NativeLlmBackendConfig + ) { + return (request: NativeLlmRequest) => { + this.dispatchRequests.push(request); + return stream(() => [ + { type: 'text_delta', text: 'vertex native' }, + { type: 'done', finish_reason: 'stop' }, + ]); + }; + } + + // eslint-disable-next-line sonarjs/no-identical-functions + protected override async fetchRemoteAttach( + url: string, + signal?: AbortSignal + ) { + this.remoteAttachmentRequests.push(url); + this.remoteAttachmentSignals.push(signal); + const response = 
this.remoteAttachmentResponses.get(url); + if (!response) { + throw new Error(`missing remote attachment stub for ${url}`); + } + return response; + } + + protected override getActiveProviderMiddleware(): ProviderMiddlewareConfig { + return this.testMiddleware; + } + + protected override async getTools(): Promise { + return this.testTools; + } + + async exposeNativeConfig() { + return await this.createNativeConfig(); + } +} + +class TestOpenAIProvider extends OpenAIProvider { + override readonly models = [ + { + id: 'gpt-4.1', + capabilities: [ + { + input: [ModelInputType.Text, ModelInputType.Image], + output: [ + ModelOutputType.Text, + ModelOutputType.Structured, + ModelOutputType.Rerank, + ], + }, + { + input: [ModelInputType.Text], + output: [ModelOutputType.Embedding], + }, + ], + }, + { + id: 'gpt-5.2', + capabilities: [ + { + input: [ModelInputType.Text, ModelInputType.Image], + output: [ + ModelOutputType.Text, + ModelOutputType.Structured, + ModelOutputType.Rerank, + ], + }, + { + input: [ModelInputType.Text], + output: [ModelOutputType.Embedding], + }, + ], + }, + { + id: 'text-embedding-3-small', + capabilities: [ + { + input: [ModelInputType.Text], + output: [ModelOutputType.Embedding], + }, + ], + }, + ]; + + readonly structuredRequests: NativeLlmStructuredRequest[] = []; + readonly embeddingRequests: NativeLlmEmbeddingRequest[] = []; + readonly rerankRequests: NativeLlmRerankRequest[] = []; + testMiddleware: ProviderMiddlewareConfig = { + rust: { + request: ['normalize_messages', 'tool_schema_rewrite'], + stream: ['stream_event_normalize', 'citation_indexing'], + }, + }; + + override get config() { + return { + apiKey: 'openai-key', + baseURL: 'https://api.openai.com/v1', + }; + } + + override configured() { + return true; + } + + protected override getActiveProviderMiddleware(): ProviderMiddlewareConfig { + return this.testMiddleware; + } + + protected override createNativeStructuredDispatch( + _backendConfig: NativeLlmBackendConfig + ) { + return 
async (request: NativeLlmStructuredRequest) => { + this.structuredRequests.push(request); + return { + id: 'structured_openai_1', + model: request.model, + output_text: '{"summary":"AFFiNE structured"}', + usage: { + prompt_tokens: 4, + completion_tokens: 3, + total_tokens: 7, + }, + finish_reason: 'stop', + }; + }; + } + + protected override createNativeEmbeddingDispatch( + _backendConfig: NativeLlmBackendConfig + ) { + return async (request: NativeLlmEmbeddingRequest) => { + this.embeddingRequests.push(request); + return { + model: request.model, + embeddings: request.inputs.map(() => [0.4, 0.5]), + usage: { + prompt_tokens: request.inputs.length, + total_tokens: request.inputs.length, + }, + }; + }; + } + + protected override createNativeRerankDispatch( + _backendConfig: NativeLlmBackendConfig + ) { + return async (request: NativeLlmRerankRequest) => { + this.rerankRequests.push(request); + return { + model: request.model, + scores: request.candidates.map(() => 0.8), + } satisfies NativeLlmRerankResponse; + }; + } +} + +class TestPerplexityProvider extends PerplexityProvider { + override get config() { + return { apiKey: 'perplexity-key' }; + } + + override configured() { + return true; + } +} + +test('NativeProviderAdapter streamText should append citation footnotes', async t => { + const adapter = new NativeProviderAdapter(mockDispatch, {}, 3); + const chunks: string[] = []; + for await (const chunk of adapter.streamText({ + model: 'gpt-5-mini', + stream: true, + messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }], + })) { + chunks.push(chunk); + } + + const text = chunks.join(''); + t.true(text.includes('Use [^1] now')); + t.true( + text.includes('[^1]: {"type":"url","url":"https%3A%2F%2Faffine.pro"}') + ); +}); + +test('NativeProviderAdapter streamObject should append citation footnotes', async t => { + const adapter = new NativeProviderAdapter(mockDispatch, {}, 3); + const chunks = []; + for await (const chunk of adapter.streamObject({ + 
model: 'gpt-5-mini', + stream: true, + messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }], + })) { + chunks.push(chunk); + } + + t.deepEqual( + chunks.map(chunk => chunk.type), + ['text-delta', 'text-delta'] + ); + const text = chunks + .filter(chunk => chunk.type === 'text-delta') + .map(chunk => chunk.textDelta) + .join(''); + t.true(text.includes('Use [^1] now')); + t.true( + text.includes('[^1]: {"type":"url","url":"https%3A%2F%2Faffine.pro"}') + ); +}); + +test('NativeProviderAdapter streamObject should append fallback attachment footnotes', async t => { + const dispatch = () => + (async function* (): AsyncIterableIterator { + yield { + type: 'tool_result', + call_id: 'call_1', + name: 'blob_read', + arguments: { blob_id: 'blob_1' }, + output: { + blobId: 'blob_1', + fileName: 'a.txt', + fileType: 'text/plain', + content: 'A', + }, + }; + yield { + type: 'tool_result', + call_id: 'call_2', + name: 'blob_read', + arguments: { blob_id: 'blob_2' }, + output: { + blobId: 'blob_2', + fileName: 'b.txt', + fileType: 'text/plain', + content: 'B', + }, + }; + yield { type: 'text_delta', text: 'Answer from files.' 
}; + yield { type: 'done', finish_reason: 'stop' }; + })(); + + const adapter = new NativeProviderAdapter(dispatch, {}, 3); + const chunks = []; + for await (const chunk of adapter.streamObject({ + model: 'gpt-5-mini', + stream: true, + messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }], + })) { + chunks.push(chunk); + } + + const text = chunks + .filter(chunk => chunk.type === 'text-delta') + .map(chunk => chunk.textDelta) + .join(''); + t.true(text.includes('Answer from files.')); + t.true(text.includes('[^1][^2]')); + t.true( + text.includes( + '[^1]: {"type":"attachment","blobId":"blob_1","fileName":"a.txt","fileType":"text/plain"}' + ) + ); + t.true( + text.includes( + '[^2]: {"type":"attachment","blobId":"blob_2","fileName":"b.txt","fileType":"text/plain"}' + ) + ); +}); + +test('NativeProviderAdapter streamObject should map tool and text events', async t => { + let round = 0; + const dispatch = (_request: NativeLlmRequest) => + (async function* (): AsyncIterableIterator { + round += 1; + if (round === 1) { + yield { + type: 'tool_call', + call_id: 'call_1', + name: 'doc_read', + arguments: { doc_id: 'a1' }, + }; + yield { type: 'done', finish_reason: 'tool_calls' }; + return; + } + yield { type: 'text_delta', text: 'ok' }; + yield { type: 'done', finish_reason: 'stop' }; + })(); + + const adapter = new NativeProviderAdapter( + dispatch, + { + doc_read: { + inputSchema: z.object({ doc_id: z.string() }), + execute: async () => ({ markdown: '# a1' }), + }, + }, + 4 + ); + + const events = []; + for await (const event of adapter.streamObject({ + model: 'gpt-5-mini', + stream: true, + messages: [{ role: 'user', content: [{ type: 'text', text: 'read' }] }], + })) { + events.push(event); + } + + t.deepEqual( + events.map(event => event.type), + ['tool-call', 'tool-result', 'text-delta'] + ); + t.deepEqual(events[0], { + type: 'tool-call', + toolCallId: 'call_1', + toolName: 'doc_read', + args: { doc_id: 'a1' }, + }); +}); + 
+test('buildNativeRequest should include rust middleware from profile', async t => { + const { request } = await buildNativeRequest({ + model: 'gpt-5-mini', + messages: [{ role: 'user', content: 'hello' }], + tools: {}, + middleware: { + rust: { + request: ['normalize_messages', 'clamp_max_tokens'], + stream: ['stream_event_normalize', 'citation_indexing'], + }, + node: { + text: ['callout'], + }, + }, + }); + + t.deepEqual(request.middleware, { + request: ['normalize_messages', 'clamp_max_tokens'], + stream: ['stream_event_normalize', 'citation_indexing'], + }); +}); + +test('buildNativeRequest should preserve non-image attachment urls for native Gemini', async t => { + const { request } = await buildNativeRequest({ + model: 'gemini-2.5-flash', + messages: [ + { + role: 'user', + content: 'summarize this attachment', + attachments: ['https://example.com/a.pdf'], + params: { mimetype: 'application/pdf' }, + }, + ], + }); + + t.deepEqual(request.messages[0]?.content, [ + { type: 'text', text: 'summarize this attachment' }, + { + type: 'file', + source: { + url: 'https://example.com/a.pdf', + media_type: 'application/pdf', + }, + }, + ]); +}); + +test('buildNativeRequest should inline data url attachments for native Gemini', async t => { + const { request } = await buildNativeRequest({ + model: 'gemini-2.5-flash', + messages: [ + { + role: 'user', + content: 'read this note', + attachments: ['data:text/plain,hello%20world'], + params: { mimetype: 'text/plain' }, + }, + ], + }); + + t.deepEqual(request.messages[0]?.content, [ + { type: 'text', text: 'read this note' }, + { + type: 'file', + source: { + media_type: 'text/plain', + data: Buffer.from('hello world', 'utf8').toString('base64'), + }, + }, + ]); +}); + +test('buildNativeRequest should classify audio attachments for native Gemini', async t => { + const { request } = await buildNativeRequest({ + model: 'gemini-2.5-flash', + messages: [ + { + role: 'user', + content: 'transcribe this clip', + attachments: 
['https://example.com/a.mp3'], + params: { mimetype: 'audio/mpeg' }, + }, + ], + }); + + t.deepEqual(request.messages[0]?.content, [ + { type: 'text', text: 'transcribe this clip' }, + { + type: 'audio', + source: { + url: 'https://example.com/a.mp3', + media_type: 'audio/mpeg', + }, + }, + ]); +}); + +test('buildNativeRequest should preserve bytes and file handle attachment sources', async t => { + const { request } = await buildNativeRequest({ + model: 'gemini-2.5-flash', + messages: [ + { + role: 'user', + content: 'inspect these assets', + attachments: [ + { + kind: 'bytes', + data: Buffer.from('hello', 'utf8').toString('base64'), + mimeType: 'text/plain', + fileName: 'hello.txt', + }, + { + kind: 'file_handle', + fileHandle: 'file_123', + mimeType: 'application/pdf', + fileName: 'report.pdf', + }, + ], + }, + ], + attachmentCapability: { + kinds: ['image', 'audio', 'file'], + sourceKinds: ['bytes', 'file_handle'], + }, + }); + + t.deepEqual(request.messages[0]?.content, [ + { type: 'text', text: 'inspect these assets' }, + { + type: 'file', + source: { + media_type: 'text/plain', + data: Buffer.from('hello', 'utf8').toString('base64'), + file_name: 'hello.txt', + }, + }, + { + type: 'file', + source: { + file_handle: 'file_123', + media_type: 'application/pdf', + file_name: 'report.pdf', + }, + }, + ]); +}); + +test('buildNativeRequest should reject attachments outside native admission matrix', async t => { + const error = await t.throwsAsync( + buildNativeRequest({ + model: 'gpt-4o', + messages: [ + { + role: 'user', + content: 'summarize this attachment', + attachments: ['https://example.com/a.pdf'], + params: { mimetype: 'application/pdf' }, + }, + ], + attachmentCapability: { + kinds: ['image'], + sourceKinds: ['url', 'data'], + allowRemoteUrls: true, + }, + }) + ); + + t.true(error instanceof CopilotPromptInvalid); + t.regex(error.message, /does not support file attachments/i); +}); + +test('buildNativeStructuredRequest should prefer explicit schema 
option', async t => { + const provider = new TestOpenAIProvider(); + const schema = z.object({ summary: z.string() }); + + await provider.structure( + { modelId: 'gpt-4.1' }, + [ + { + role: 'system', + content: 'Return JSON only.', + }, + { + role: 'user', + content: 'Summarize AFFiNE in one sentence.', + }, + ], + { schema } + ); + + t.deepEqual(provider.structuredRequests[0]?.schema, { + type: 'object', + properties: { summary: { type: 'string' } }, + required: ['summary'], + additionalProperties: false, + }); +}); + +test('buildNativeStructuredRequest should preserve caller strictness override', async t => { + const provider = new TestOpenAIProvider(); + + await provider.structure( + { modelId: 'gpt-4.1' }, + [ + { role: 'system', content: 'Return JSON only.' }, + { role: 'user', content: 'Summarize AFFiNE in one sentence.' }, + ], + { schema: z.object({ summary: z.string() }), strict: false } + ); + + t.is(provider.structuredRequests[0]?.strict, false); +}); + +test('NativeProviderAdapter streamText should skip citation footnotes when disabled', async t => { + const adapter = new NativeProviderAdapter(mockDispatch, {}, 3, { + nodeTextMiddleware: ['callout'], + }); + const chunks: string[] = []; + for await (const chunk of adapter.streamText({ + model: 'gpt-5-mini', + stream: true, + messages: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }], + })) { + chunks.push(chunk); + } + + const text = chunks.join(''); + t.true(text.includes('Use [^1] now')); + t.false( + text.includes('[^1]: {"type":"url","url":"https%3A%2F%2Faffine.pro"}') + ); +}); + +test('GeminiProvider should use native path for text-only requests', async t => { + const provider = new TestGeminiProvider(); + + const result = await provider.text( + { modelId: 'gemini-2.5-flash' }, + [{ role: 'user', content: 'hello' }], + { reasoning: true } + ); + + t.is(result, 'native'); + t.is(provider.dispatchRequests.length, 1); + t.deepEqual(provider.dispatchRequests[0]?.reasoning, { + 
include_thoughts: true, + thinking_budget: 12000, + }); + t.deepEqual(provider.dispatchRequests[0]?.middleware, { + request: ['normalize_messages', 'tool_schema_rewrite'], + stream: ['stream_event_normalize', 'citation_indexing'], + }); +}); + +test('GeminiProvider should use native path for structured requests', async t => { + const provider = new TestGeminiProvider(); + + const schema = z.object({ summary: z.string() }); + const result = await provider.structure( + { modelId: 'gemini-2.5-flash' }, + [ + { + role: 'system', + content: 'Return JSON only.', + }, + { + role: 'user', + content: 'Summarize AFFiNE in one short sentence.', + }, + ], + { schema } + ); + + t.is(provider.structuredRequests.length, 1); + t.deepEqual(provider.structuredRequests[0]?.schema, { + type: 'object', + properties: { + summary: { + type: 'string', + }, + }, + required: ['summary'], + additionalProperties: false, + }); + t.deepEqual(JSON.parse(result), { summary: 'AFFiNE native' }); +}); + +test('GeminiProvider should retry only reparsable structured responses', async t => { + const provider = new TestGeminiProvider(); + let attempts = 0; + provider.structuredFactory = () => { + attempts += 1; + return { + id: `structured_retry_${attempts}`, + model: 'gemini-2.5-flash', + output_text: + attempts === 1 ? 
'```json\n{"summary":1}\n```' : '{"summary":"ok"}', + usage: { + prompt_tokens: 4, + completion_tokens: 3, + total_tokens: 7, + }, + finish_reason: 'stop', + }; + }; + + const result = await provider.structure( + { modelId: 'gemini-2.5-flash' }, + [ + { + role: 'system', + content: 'Return JSON only.', + }, + { + role: 'user', + content: 'Summarize AFFiNE in one short sentence.', + }, + ], + { schema: z.object({ summary: z.string() }), maxRetries: 2 } + ); + + t.is(attempts, 2); + t.deepEqual(JSON.parse(result), { summary: 'ok' }); +}); + +test('GeminiProvider should treat maxRetries as retry count for backend failures', async t => { + const provider = new TestGeminiProvider(); + let attempts = 0; + provider.structuredFactory = () => { + attempts += 1; + throw new Error('backend down'); + }; + + const error = await t.throwsAsync( + provider.structure( + { modelId: 'gemini-2.5-flash' }, + [ + { + role: 'system', + content: 'Return JSON only.', + }, + { + role: 'user', + content: 'Summarize AFFiNE in one short sentence.', + }, + ], + { schema: z.object({ summary: z.string() }), maxRetries: 2 } + ) + ); + + t.is(attempts, 3); + t.deepEqual(provider.retryDelays, [2_000, 4_000]); + t.regex(error.message, /backend down/); +}); + +test('GeminiProvider should use native structured path for audio attachments', async t => { + const provider = new TestGeminiProvider(); + const inlineData = Buffer.from('audio-bytes', 'utf8').toString('base64'); + provider.remoteAttachmentResponses.set('https://example.com/a.mp3', { + data: inlineData, + mimeType: 'audio/mpeg', + }); + provider.structuredFactory = () => ({ + id: 'structured_audio_1', + model: 'gemini-2.5-flash', + output_text: '[{"a":"Speaker 1","s":0,"e":1,"t":"Hello"}]', + usage: { + prompt_tokens: 4, + completion_tokens: 3, + total_tokens: 7, + }, + finish_reason: 'stop', + }); + + const result = await provider.structure( + { modelId: 'gemini-2.5-flash' }, + [ + { + role: 'system', + content: 'Return JSON only.', + }, + { + 
role: 'user', + content: 'transcribe the audio', + attachments: ['https://example.com/a.mp3'], + params: { mimetype: 'audio/mpeg' }, + }, + ], + { + schema: z.array( + z.object({ a: z.string(), s: z.number(), e: z.number(), t: z.string() }) + ), + } + ); + + t.is(provider.structuredRequests.length, 1); + t.deepEqual(provider.structuredRequests[0]?.messages[1]?.content, [ + { type: 'text', text: 'transcribe the audio' }, + { + type: 'audio', + source: { + data: inlineData, + media_type: 'audio/mpeg', + }, + }, + ]); + t.deepEqual(provider.remoteAttachmentRequests, ['https://example.com/a.mp3']); + t.deepEqual(JSON.parse(result), [{ a: 'Speaker 1', s: 0, e: 1, t: 'Hello' }]); +}); + +test('GeminiProvider should use native path for embeddings', async t => { + const provider = new TestGeminiProvider(); + + const result = await provider.embedding( + { modelId: 'gemini-embedding-001' }, + ['first', 'second'], + { dimensions: 3 } + ); + + t.deepEqual(result, [ + [0.1, 0.2], + [1.1, 1.2], + ]); + t.is(provider.embeddingRequests.length, 1); + t.deepEqual(provider.embeddingRequests[0], { + model: 'gemini-embedding-001', + inputs: ['first', 'second'], + dimensions: 3, + task_type: 'RETRIEVAL_DOCUMENT', + }); +}); + +test('GeminiProvider should use native path for non-image attachments', async t => { + const provider = new TestGeminiProvider(); + const inlineData = Buffer.from('pdf-bytes', 'utf8').toString('base64'); + provider.remoteAttachmentResponses.set('https://example.com/a.pdf', { + data: inlineData, + mimeType: 'application/pdf', + }); + const messages: PromptMessage[] = [ + { + role: 'user', + content: 'summarize this file', + attachments: ['https://example.com/a.pdf'], + params: { mimetype: 'application/pdf' }, + }, + ]; + + const result = await provider.text( + { modelId: 'gemini-2.5-flash' }, + messages, + {} + ); + + t.is(result, 'native'); + t.is(provider.dispatchRequests.length, 1); + t.deepEqual(provider.dispatchRequests[0]?.messages[0]?.content, [ + { type: 
'text', text: 'summarize this file' }, + { + type: 'file', + source: { + data: inlineData, + media_type: 'application/pdf', + }, + }, + ]); +}); + +test('GeminiProvider should inline remote image attachments for text requests', async t => { + const provider = new TestGeminiProvider(); + const inlineData = Buffer.from('image-bytes', 'utf8').toString('base64'); + provider.remoteAttachmentResponses.set('https://example.com/a.jpg', { + data: inlineData, + mimeType: 'image/jpeg', + }); + + const result = await provider.text({ modelId: 'gemini-2.5-flash' }, [ + { + role: 'user', + content: 'describe this image', + attachments: ['https://example.com/a.jpg'], + }, + ]); + + t.is(result, 'native'); + t.deepEqual(provider.dispatchRequests[0]?.messages[0]?.content, [ + { type: 'text', text: 'describe this image' }, + { + type: 'image', + source: { + data: inlineData, + media_type: 'image/jpeg', + }, + }, + ]); +}); + +test('GeminiProvider should pass abort signal to remote attachment prefetch', async t => { + const provider = new TestGeminiProvider(); + provider.remoteAttachmentResponses.set('https://example.com/a.jpg', { + data: Buffer.from('image-bytes', 'utf8').toString('base64'), + mimeType: 'image/jpeg', + }); + const controller = new AbortController(); + + await provider.text( + { modelId: 'gemini-2.5-flash' }, + [ + { + role: 'user', + content: 'describe this image', + attachments: ['https://example.com/a.jpg'], + }, + ], + { signal: controller.signal } + ); + + t.deepEqual(provider.remoteAttachmentRequests, ['https://example.com/a.jpg']); + t.is(provider.remoteAttachmentSignals[0], controller.signal); +}); + +test('GeminiProvider should classify downloaded audio-only WebM attachments as audio', async t => { + const provider = new TestGeminiProvider(); + const inlineData = Buffer.from('audio-bytes', 'utf8').toString('base64'); + provider.remoteAttachmentResponses.set('https://example.com/a.webm', { + data: inlineData, + mimeType: 'audio/webm', + }); + + const result = 
await provider.text( + { modelId: 'gemini-2.5-flash' }, + [ + { + role: 'user', + content: 'transcribe this clip', + attachments: ['https://example.com/a.webm'], + }, + ], + {} + ); + + t.is(result, 'native'); + t.deepEqual(provider.dispatchRequests[0]?.messages[0]?.content, [ + { type: 'text', text: 'transcribe this clip' }, + { type: 'audio', source: { data: inlineData, media_type: 'audio/webm' } }, + ]); +}); + +test('GeminiProvider should preserve Google file urls for native Gemini API', async t => { + const provider = new TestGeminiProvider(); + + await provider.text({ modelId: 'gemini-2.5-flash' }, [ + { + role: 'user', + content: 'summarize this file', + attachments: [ + 'https://generativelanguage.googleapis.com/v1beta/files/file-123', + ], + params: { mimetype: 'application/pdf' }, + }, + ]); + + t.deepEqual(provider.remoteAttachmentRequests, []); + t.deepEqual(provider.dispatchRequests[0]?.messages[0]?.content, [ + { type: 'text', text: 'summarize this file' }, + { + type: 'file', + source: { + url: 'https://generativelanguage.googleapis.com/v1beta/files/file-123', + media_type: 'application/pdf', + }, + }, + ]); +}); + +test('PerplexityProvider should ignore attachments during text model matching', async t => { + const provider = new TestPerplexityProvider(); + let capturedRequest: NativeLlmRequest | undefined; + + (provider as any).getActiveProviderMiddleware = () => ({}); + (provider as any).getTools = async () => ({}); + (provider as any).createNativeAdapter = () => ({ + text: async (request: NativeLlmRequest) => { + capturedRequest = request; + return 'ok'; + }, + }); + + const result = await provider.text( + { modelId: 'sonar' }, + [ + { + role: 'user', + content: 'summarize this', + attachments: ['https://example.com/a.pdf'], + params: { mimetype: 'application/pdf' }, + }, + ], + {} + ); + + t.is(result, 'ok'); + t.deepEqual(capturedRequest?.messages[0]?.content, [ + { type: 'text', text: 'summarize this' }, + ]); +}); + +test('GeminiProvider 
should reject unsupported attachment schemes at input validation', async t => { + const provider = new TestGeminiProvider(); + + const error = await t.throwsAsync( + provider.text( + { modelId: 'gemini-2.5-flash' }, + [ + { + role: 'user', + content: 'read this attachment', + attachments: ['blob:https://example.com/file-id'], + params: { mimetype: 'application/pdf' }, + }, + ], + {} + ) + ); + + t.true(error instanceof CopilotPromptInvalid); + t.regex(error.message, /attachments must use https\?:\/\/, gs:\/\/ or data:/); + t.is(provider.dispatchRequests.length, 0); +}); + +test('GeminiProvider should validate malformed attachments before canonicalization', async t => { + const provider = new TestGeminiProvider(); + + const error = await t.throwsAsync( + provider.text( + { modelId: 'gemini-2.5-flash' }, + [ + { + role: 'user', + content: 'read this attachment', + attachments: [{ kind: 'url' }], + }, + ] as any, + {} + ) + ); + + t.true(error instanceof CopilotPromptInvalid); + t.regex(error.message, /attachments\[0\]/); + t.is(provider.dispatchRequests.length, 0); +}); + +test('GeminiProvider should drive tool loop on native path', async t => { + const provider = new TestGeminiProvider(); + provider.testTools = { + doc_read: { + inputSchema: z.object({ doc_id: z.string() }), + execute: async args => ({ markdown: `# ${(args as any).doc_id}` }), + }, + }; + provider.dispatchFactory = request => { + const hasToolResult = request.messages.some( + message => message.role === 'tool' + ); + if (!hasToolResult) { + return [ + { + type: 'tool_call', + call_id: 'call_1', + name: 'doc_read', + arguments: { doc_id: 'a1' }, + }, + { type: 'done', finish_reason: 'tool_calls' }, + ]; + } + + return [ + { type: 'text_delta', text: 'after tool' }, + { type: 'done', finish_reason: 'stop' }, + ]; + }; + + const result = await provider.text( + { modelId: 'gemini-2.5-flash' }, + [{ role: 'user', content: 'read doc a1' }], + {} + ); + + t.true(result.includes('after tool')); + 
t.is(provider.dispatchRequests.length, 2); + t.true( + provider.dispatchRequests[1]?.messages.some( + message => message.role === 'tool' + ) + ); +}); + +test('GeminiVertexProvider should prefetch bearer token for native config', async t => { + const provider = new TestGeminiVertexProvider(); + + const config = await provider.exposeNativeConfig(); + + t.deepEqual(config, { + base_url: 'https://vertex.example', + auth_token: 'vertex-token', + request_layer: 'gemini_vertex', + }); +}); + +test('GeminiVertexProvider should preserve remote http attachments like Vertex SDK', async t => { + const provider = new TestGeminiVertexProvider(); + + const result = await provider.text( + { modelId: 'gemini-2.5-flash' }, + [ + { + role: 'user', + content: 'transcribe the audio', + attachments: ['https://example.com/a.mp3'], + }, + ], + {} + ); + + t.is(result, 'vertex native'); + t.deepEqual(provider.remoteAttachmentRequests, []); + t.deepEqual(provider.dispatchRequests[0]?.messages[0]?.content, [ + { type: 'text', text: 'transcribe the audio' }, + { + type: 'audio', + source: { + url: 'https://example.com/a.mp3', + media_type: 'audio/mpeg', + }, + }, + ]); +}); + +test('GeminiVertexProvider should preserve gs urls for native Vertex requests', async t => { + const provider = new TestGeminiVertexProvider(); + + const result = await provider.text( + { modelId: 'gemini-2.5-flash' }, + [ + { + role: 'user', + content: 'transcribe the audio', + attachments: ['gs://bucket/audio.opus'], + }, + ], + {} + ); + + t.is(result, 'vertex native'); + t.deepEqual(provider.remoteAttachmentRequests, []); + t.deepEqual(provider.dispatchRequests[0]?.messages[0]?.content, [ + { type: 'text', text: 'transcribe the audio' }, + { + type: 'audio', + source: { + url: 'gs://bucket/audio.opus', + media_type: 'audio/opus', + }, + }, + ]); +}); + +test('OpenAIProvider should use native structured dispatch', async t => { + const provider = new TestOpenAIProvider(); + const schema = z.object({ summary: 
z.string() }); + + const result = await provider.structure( + { modelId: 'gpt-4.1' }, + [ + { + role: 'system', + content: 'Return JSON only.', + }, + { + role: 'user', + content: 'Summarize AFFiNE in one sentence.', + }, + ], + { schema } + ); + + t.deepEqual(JSON.parse(result), { summary: 'AFFiNE structured' }); + t.is(provider.structuredRequests.length, 1); + t.deepEqual(provider.structuredRequests[0]?.schema, { + type: 'object', + properties: { + summary: { + type: 'string', + }, + }, + required: ['summary'], + additionalProperties: false, + }); +}); + +test('OpenAIProvider should use native embedding dispatch', async t => { + const provider = new TestOpenAIProvider(); + + const result = await provider.embedding( + { modelId: 'text-embedding-3-small' }, + ['alpha', 'beta'], + { dimensions: 8 } + ); + + t.deepEqual(result, [ + [0.4, 0.5], + [0.4, 0.5], + ]); + t.is(provider.embeddingRequests.length, 1); + t.deepEqual(provider.embeddingRequests[0], { + model: 'text-embedding-3-small', + inputs: ['alpha', 'beta'], + dimensions: 8, + task_type: 'RETRIEVAL_DOCUMENT', + }); +}); + +test('OpenAIProvider should use native rerank dispatch', async t => { + const provider = new TestOpenAIProvider(); + + const scores = await provider.rerank( + { modelId: 'gpt-4.1' }, + { + query: 'programming', + candidates: [ + { id: 'react', text: 'React is a UI library.' }, + { id: 'weather', text: 'The park is sunny today.' }, + ], + } + ); + + t.deepEqual(scores, [0.8, 0.8]); + t.is(provider.rerankRequests.length, 1); + t.is(provider.rerankRequests[0]?.model, 'gpt-4.1'); + t.is(provider.rerankRequests[0]?.query, 'programming'); + t.deepEqual(provider.rerankRequests[0]?.candidates, [ + { id: 'react', text: 'React is a UI library.' }, + { id: 'weather', text: 'The park is sunny today.' 
}, + ]); +}); + +test('OpenAIProvider rerank should normalize native dispatch errors', async t => { + class ErroringOpenAIProvider extends TestOpenAIProvider { + protected override createNativeRerankDispatch( + _backendConfig: NativeLlmBackendConfig + ) { + return async () => { + throw new Error('native rerank exploded'); + }; + } + } + + const provider = new ErroringOpenAIProvider(); + + const error = await t.throwsAsync( + provider.rerank( + { modelId: 'gpt-4.1' }, + { + query: 'programming', + candidates: [{ id: 'react', text: 'React is a UI library.' }], + } + ) + ); + + t.true(error instanceof CopilotProviderSideError); + t.regex(error.message, /native rerank exploded/i); +}); diff --git a/packages/backend/server/src/__tests__/copilot/provider-middleware.spec.ts b/packages/backend/server/src/__tests__/copilot/provider-middleware.spec.ts new file mode 100644 index 0000000000..81bda1a604 --- /dev/null +++ b/packages/backend/server/src/__tests__/copilot/provider-middleware.spec.ts @@ -0,0 +1,56 @@ +import test from 'ava'; + +import { resolveProviderMiddleware } from '../../plugins/copilot/providers/provider-middleware'; +import { buildProviderRegistry } from '../../plugins/copilot/providers/provider-registry'; +import { CopilotProviderType } from '../../plugins/copilot/providers/types'; + +test('resolveProviderMiddleware should include anthropic defaults', t => { + const middleware = resolveProviderMiddleware(CopilotProviderType.Anthropic); + + t.deepEqual(middleware.rust?.request, [ + 'normalize_messages', + 'tool_schema_rewrite', + ]); + t.deepEqual(middleware.rust?.stream, [ + 'stream_event_normalize', + 'citation_indexing', + ]); + t.deepEqual(middleware.node?.text, ['citation_footnote', 'callout']); +}); + +test('resolveProviderMiddleware should merge defaults and overrides', t => { + const middleware = resolveProviderMiddleware(CopilotProviderType.OpenAI, { + rust: { request: ['clamp_max_tokens'] }, + node: { text: ['thinking_format'] }, + }); + + 
t.deepEqual(middleware.rust?.request, [ + 'normalize_messages', + 'clamp_max_tokens', + ]); + t.deepEqual(middleware.node?.text, [ + 'citation_footnote', + 'callout', + 'thinking_format', + ]); +}); + +test('buildProviderRegistry should normalize profile middleware defaults', t => { + const registry = buildProviderRegistry({ + profiles: [ + { + id: 'openai-main', + type: CopilotProviderType.OpenAI, + config: { apiKey: '1' }, + }, + ], + }); + + const profile = registry.profiles.get('openai-main'); + t.truthy(profile); + t.deepEqual(profile?.middleware.rust?.stream, [ + 'stream_event_normalize', + 'citation_indexing', + ]); + t.deepEqual(profile?.middleware.node?.text, ['citation_footnote', 'callout']); +}); diff --git a/packages/backend/server/src/__tests__/copilot/provider-native.spec.ts b/packages/backend/server/src/__tests__/copilot/provider-native.spec.ts new file mode 100644 index 0000000000..58e04d4f63 --- /dev/null +++ b/packages/backend/server/src/__tests__/copilot/provider-native.spec.ts @@ -0,0 +1,200 @@ +import serverNativeModule from '@affine/server-native'; +import test from 'ava'; + +import type { NativeLlmRerankRequest } from '../../native'; +import { ProviderMiddlewareConfig } from '../../plugins/copilot/config'; +import { + normalizeOpenAIOptionsForModel, + OpenAIProvider, +} from '../../plugins/copilot/providers/openai'; +import { CopilotProvider } from '../../plugins/copilot/providers/provider'; +import { + CopilotProviderType, + ModelInputType, + ModelOutputType, +} from '../../plugins/copilot/providers/types'; + +class TestOpenAIProvider extends CopilotProvider<{ apiKey: string }> { + readonly type = CopilotProviderType.OpenAI; + readonly models = [ + { + id: 'gpt-5-mini', + capabilities: [ + { + input: [ModelInputType.Text], + output: [ModelOutputType.Text], + defaultForOutputType: true, + }, + ], + }, + ]; + + configured() { + return true; + } + + async text(_cond: any, _messages: any[], _options?: any) { + return ''; + } + + async 
*streamText(_cond: any, _messages: any[], _options?: any) { + yield ''; + } + + exposeMetricLabels() { + return this.metricLabels('gpt-5-mini'); + } + + exposeMiddleware() { + return this.getActiveProviderMiddleware(); + } +} + +class NativeRerankProtocolProvider extends OpenAIProvider { + override readonly models = [ + { + id: 'gpt-5.2', + capabilities: [ + { + input: [ModelInputType.Text], + output: [ModelOutputType.Text, ModelOutputType.Rerank], + defaultForOutputType: true, + }, + ], + }, + ]; + + override get config() { + return { + apiKey: 'test-key', + baseURL: 'https://api.openai.com/v1', + oldApiStyle: false, + }; + } + + override configured() { + return true; + } +} + +function createProvider(profileMiddleware?: ProviderMiddlewareConfig) { + const provider = new TestOpenAIProvider(); + (provider as any).AFFiNEConfig = { + copilot: { + providers: { + profiles: [ + { + id: 'openai-main', + type: CopilotProviderType.OpenAI, + config: { apiKey: 'test' }, + middleware: profileMiddleware, + }, + ], + defaults: {}, + openai: { apiKey: 'legacy' }, + }, + }, + }; + return provider; +} + +test('metricLabels should include active provider id', t => { + const provider = createProvider(); + const labels = provider.runWithProfile('openai-main', () => + provider.exposeMetricLabels() + ); + t.is(labels.providerId, 'openai-main'); +}); + +test('getActiveProviderMiddleware should merge defaults with profile override', t => { + const provider = createProvider({ + rust: { request: ['clamp_max_tokens'] }, + node: { text: ['thinking_format'] }, + }); + + const middleware = provider.runWithProfile('openai-main', () => + provider.exposeMiddleware() + ); + + t.deepEqual(middleware.rust?.request, [ + 'normalize_messages', + 'clamp_max_tokens', + ]); + t.deepEqual(middleware.rust?.stream, [ + 'stream_event_normalize', + 'citation_indexing', + ]); + t.deepEqual(middleware.node?.text, [ + 'citation_footnote', + 'callout', + 'thinking_format', + ]); +}); + 
+test('normalizeOpenAIOptionsForModel should drop sampling knobs for gpt-5.2', t => { + t.deepEqual( + normalizeOpenAIOptionsForModel( + { + temperature: 0.7, + topP: 0.8, + presencePenalty: 0.2, + frequencyPenalty: 0.1, + maxTokens: 128, + }, + 'gpt-5.4' + ), + { maxTokens: 128 } + ); +}); + +test('normalizeOpenAIOptionsForModel should keep options for gpt-4.1', t => { + t.deepEqual( + normalizeOpenAIOptionsForModel( + { temperature: 0.7, topP: 0.8, maxTokens: 128 }, + 'gpt-4.1' + ), + { temperature: 0.7, topP: 0.8, maxTokens: 128 } + ); +}); + +test('OpenAI rerank should always use chat-completions native protocol', async t => { + const provider = new NativeRerankProtocolProvider(); + let capturedProtocol: string | undefined; + let capturedRequest: NativeLlmRerankRequest | undefined; + + const original = (serverNativeModule as any).llmRerankDispatch; + (serverNativeModule as any).llmRerankDispatch = ( + protocol: string, + _backendConfigJson: string, + requestJson: string + ) => { + capturedProtocol = protocol; + capturedRequest = JSON.parse(requestJson) as NativeLlmRerankRequest; + return JSON.stringify({ model: 'gpt-5.2', scores: [0.9, 0.1] }); + }; + t.teardown(() => { + (serverNativeModule as any).llmRerankDispatch = original; + }); + + const scores = await provider.rerank( + { modelId: 'gpt-5.2' }, + { + query: 'programming', + candidates: [ + { id: 'react', text: 'React is a UI library.' }, + { id: 'weather', text: 'The weather is sunny today.' }, + ], + } + ); + + t.deepEqual(scores, [0.9, 0.1]); + t.is(capturedProtocol, 'openai_chat'); + t.deepEqual(capturedRequest, { + model: 'gpt-5.2', + query: 'programming', + candidates: [ + { id: 'react', text: 'React is a UI library.' }, + { id: 'weather', text: 'The weather is sunny today.' 
}, + ], + }); +}); diff --git a/packages/backend/server/src/__tests__/copilot/provider-registry.spec.ts b/packages/backend/server/src/__tests__/copilot/provider-registry.spec.ts new file mode 100644 index 0000000000..6412c42abe --- /dev/null +++ b/packages/backend/server/src/__tests__/copilot/provider-registry.spec.ts @@ -0,0 +1,168 @@ +import test from 'ava'; + +import { + buildProviderRegistry, + resolveModel, + stripProviderPrefix, +} from '../../plugins/copilot/providers/provider-registry'; +import { + CopilotProviderType, + ModelOutputType, +} from '../../plugins/copilot/providers/types'; + +test('buildProviderRegistry should keep explicit profile over legacy compatibility profile', t => { + const registry = buildProviderRegistry({ + profiles: [ + { + id: 'openai-default', + type: CopilotProviderType.OpenAI, + priority: 100, + config: { apiKey: 'new' }, + }, + ], + openai: { apiKey: 'legacy' }, + }); + + const profile = registry.profiles.get('openai-default'); + t.truthy(profile); + t.deepEqual(profile?.config, { apiKey: 'new' }); +}); + +test('buildProviderRegistry should reject duplicated profile ids', t => { + const error = t.throws(() => + buildProviderRegistry({ + profiles: [ + { + id: 'openai-main', + type: CopilotProviderType.OpenAI, + config: { apiKey: '1' }, + }, + { + id: 'openai-main', + type: CopilotProviderType.OpenAI, + config: { apiKey: '2' }, + }, + ], + }) + ) as Error; + + t.truthy(error); + t.regex(error.message, /Duplicated copilot provider profile id/); +}); + +test('buildProviderRegistry should reject defaults that reference unknown providers', t => { + const error = t.throws(() => + buildProviderRegistry({ + profiles: [ + { + id: 'openai-main', + type: CopilotProviderType.OpenAI, + config: { apiKey: '1' }, + }, + ], + defaults: { + fallback: 'unknown-provider', + }, + }) + ) as Error; + + t.truthy(error); + t.regex(error.message, /defaults references unknown providerId/); +}); + +test('resolveModel should support explicit provider prefix 
and keep slash models untouched', t => { + const registry = buildProviderRegistry({ + profiles: [ + { + id: 'openai-main', + type: CopilotProviderType.OpenAI, + config: { apiKey: '1' }, + }, + { + id: 'fal-main', + type: CopilotProviderType.FAL, + config: { apiKey: '2' }, + }, + ], + }); + + const prefixed = resolveModel({ + registry, + modelId: 'openai-main/gpt-5-mini', + }); + t.deepEqual(prefixed, { + rawModelId: 'openai-main/gpt-5-mini', + modelId: 'gpt-5-mini', + explicitProviderId: 'openai-main', + candidateProviderIds: ['openai-main'], + }); + + const slashModel = resolveModel({ + registry, + modelId: 'lora/image-to-image', + }); + t.is(slashModel.modelId, 'lora/image-to-image'); + t.false(slashModel.candidateProviderIds.includes('lora')); +}); + +test('resolveModel should follow defaults -> fallback -> order and apply filters', t => { + const registry = buildProviderRegistry({ + profiles: [ + { + id: 'openai-main', + type: CopilotProviderType.OpenAI, + priority: 10, + config: { apiKey: '1' }, + }, + { + id: 'anthropic-main', + type: CopilotProviderType.Anthropic, + priority: 5, + config: { apiKey: '2' }, + }, + { + id: 'fal-main', + type: CopilotProviderType.FAL, + priority: 1, + config: { apiKey: '3' }, + }, + ], + defaults: { + [ModelOutputType.Text]: 'anthropic-main', + fallback: 'openai-main', + }, + }); + + const routed = resolveModel({ + registry, + outputType: ModelOutputType.Text, + preferredProviderIds: ['openai-main', 'fal-main'], + }); + + t.deepEqual(routed.candidateProviderIds, ['openai-main', 'fal-main']); +}); + +test('stripProviderPrefix should only strip matched provider prefix', t => { + const registry = buildProviderRegistry({ + profiles: [ + { + id: 'openai-main', + type: CopilotProviderType.OpenAI, + config: { apiKey: '1' }, + }, + ], + }); + + t.is( + stripProviderPrefix(registry, 'openai-main', 'openai-main/gpt-5-mini'), + 'gpt-5-mini' + ); + t.is( + stripProviderPrefix(registry, 'openai-main', 'another-main/gpt-5-mini'), + 
'another-main/gpt-5-mini' + ); + t.is( + stripProviderPrefix(registry, 'openai-main', 'gpt-5-mini'), + 'gpt-5-mini' + ); +}); diff --git a/packages/backend/server/src/__tests__/copilot/tool-call-loop.spec.ts b/packages/backend/server/src/__tests__/copilot/tool-call-loop.spec.ts new file mode 100644 index 0000000000..846d7d96eb --- /dev/null +++ b/packages/backend/server/src/__tests__/copilot/tool-call-loop.spec.ts @@ -0,0 +1,288 @@ +import test from 'ava'; +import { z } from 'zod'; + +import { NativeLlmRequest, NativeLlmStreamEvent } from '../../native'; +import { + ToolCallAccumulator, + ToolCallLoop, + ToolSchemaExtractor, +} from '../../plugins/copilot/providers/loop'; + +test('ToolCallAccumulator should merge deltas and complete tool call', t => { + const accumulator = new ToolCallAccumulator(); + + accumulator.feedDelta({ + type: 'tool_call_delta', + call_id: 'call_1', + name: 'doc_read', + arguments_delta: '{"doc_id":"', + }); + accumulator.feedDelta({ + type: 'tool_call_delta', + call_id: 'call_1', + arguments_delta: 'a1"}', + }); + + const completed = accumulator.complete({ + type: 'tool_call', + call_id: 'call_1', + name: 'doc_read', + arguments: { doc_id: 'a1' }, + }); + + t.deepEqual(completed, { + id: 'call_1', + name: 'doc_read', + args: { doc_id: 'a1' }, + rawArgumentsText: '{"doc_id":"a1"}', + thought: undefined, + }); +}); + +test('ToolCallAccumulator should preserve invalid JSON instead of swallowing it', t => { + const accumulator = new ToolCallAccumulator(); + + accumulator.feedDelta({ + type: 'tool_call_delta', + call_id: 'call_1', + name: 'doc_read', + arguments_delta: '{"doc_id":', + }); + + const pending = accumulator.drainPending(); + + t.is(pending.length, 1); + t.deepEqual(pending[0]?.id, 'call_1'); + t.deepEqual(pending[0]?.name, 'doc_read'); + t.deepEqual(pending[0]?.args, {}); + t.is(pending[0]?.rawArgumentsText, '{"doc_id":'); + t.truthy(pending[0]?.argumentParseError); +}); + +test('ToolCallAccumulator should prefer native canonical 
tool arguments metadata', t => { + const accumulator = new ToolCallAccumulator(); + + accumulator.feedDelta({ + type: 'tool_call_delta', + call_id: 'call_1', + name: 'doc_read', + arguments_delta: '{"stale":true}', + }); + + const completed = accumulator.complete({ + type: 'tool_call', + call_id: 'call_1', + name: 'doc_read', + arguments: {}, + arguments_text: '{"doc_id":"a1"}', + arguments_error: 'invalid json', + }); + + t.deepEqual(completed, { + id: 'call_1', + name: 'doc_read', + args: {}, + rawArgumentsText: '{"doc_id":"a1"}', + argumentParseError: 'invalid json', + thought: undefined, + }); +}); + +test('ToolSchemaExtractor should convert zod schema to json schema', t => { + const toolSet = { + doc_read: { + description: 'Read doc', + inputSchema: z.object({ + doc_id: z.string(), + limit: z.number().optional(), + }), + execute: async () => ({}), + }, + }; + + const extracted = ToolSchemaExtractor.extract(toolSet); + + t.deepEqual(extracted, [ + { + name: 'doc_read', + description: 'Read doc', + parameters: { + type: 'object', + properties: { + doc_id: { type: 'string' }, + limit: { type: 'number' }, + }, + additionalProperties: false, + required: ['doc_id'], + }, + }, + ]); +}); + +test('ToolCallLoop should execute tool call and continue to next round', async t => { + const dispatchRequests: NativeLlmRequest[] = []; + const originalMessages = [{ role: 'user', content: 'read doc' }] as const; + const signal = new AbortController().signal; + + const dispatch = (request: NativeLlmRequest) => { + dispatchRequests.push(request); + const round = dispatchRequests.length; + + return (async function* (): AsyncIterableIterator { + if (round === 1) { + yield { + type: 'tool_call_delta', + call_id: 'call_1', + name: 'doc_read', + arguments_delta: '{"doc_id":"a1"}', + }; + yield { + type: 'tool_call', + call_id: 'call_1', + name: 'doc_read', + arguments: { doc_id: 'a1' }, + }; + yield { type: 'done', finish_reason: 'tool_calls' }; + return; + } + + yield { type: 
'text_delta', text: 'done' }; + yield { type: 'done', finish_reason: 'stop' }; + })(); + }; + + let executedArgs: Record | null = null; + let executedMessages: unknown; + let executedSignal: AbortSignal | undefined; + const loop = new ToolCallLoop( + dispatch, + { + doc_read: { + inputSchema: z.object({ doc_id: z.string() }), + execute: async (args, options) => { + executedArgs = args; + executedMessages = options.messages; + executedSignal = options.signal; + return { markdown: '# doc' }; + }, + }, + }, + 4 + ); + + const events: NativeLlmStreamEvent[] = []; + for await (const event of loop.run( + { + model: 'gpt-5-mini', + stream: true, + messages: [ + { role: 'user', content: [{ type: 'text', text: 'read doc' }] }, + ], + }, + signal, + [...originalMessages] + )) { + events.push(event); + } + + t.deepEqual(executedArgs, { doc_id: 'a1' }); + t.deepEqual(executedMessages, originalMessages); + t.is(executedSignal, signal); + t.true( + dispatchRequests[1]?.messages.some(message => message.role === 'tool') + ); + t.deepEqual(dispatchRequests[1]?.messages[1]?.content, [ + { + type: 'tool_call', + call_id: 'call_1', + name: 'doc_read', + arguments: { doc_id: 'a1' }, + arguments_text: '{"doc_id":"a1"}', + arguments_error: undefined, + thought: undefined, + }, + ]); + t.deepEqual(dispatchRequests[1]?.messages[2]?.content, [ + { + type: 'tool_result', + call_id: 'call_1', + name: 'doc_read', + arguments: { doc_id: 'a1' }, + arguments_text: '{"doc_id":"a1"}', + arguments_error: undefined, + output: { markdown: '# doc' }, + is_error: undefined, + }, + ]); + t.deepEqual( + events.map(event => event.type), + ['tool_call', 'tool_result', 'text_delta', 'done'] + ); +}); + +test('ToolCallLoop should surface invalid JSON as tool error without executing', async t => { + let executed = false; + let round = 0; + const loop = new ToolCallLoop( + request => { + round += 1; + const hasToolResult = request.messages.some( + message => message.role === 'tool' + ); + return (async 
function* (): AsyncIterableIterator { + if (!hasToolResult && round === 1) { + yield { + type: 'tool_call_delta', + call_id: 'call_1', + name: 'doc_read', + arguments_delta: '{"doc_id":', + }; + yield { type: 'done', finish_reason: 'tool_calls' }; + return; + } + + yield { type: 'done', finish_reason: 'stop' }; + })(); + }, + { + doc_read: { + inputSchema: z.object({ doc_id: z.string() }), + execute: async () => { + executed = true; + return { markdown: '# doc' }; + }, + }, + }, + 2 + ); + + const events: NativeLlmStreamEvent[] = []; + for await (const event of loop.run({ + model: 'gpt-5-mini', + stream: true, + messages: [{ role: 'user', content: [{ type: 'text', text: 'read doc' }] }], + })) { + events.push(event); + } + + t.false(executed); + t.true(events[0]?.type === 'tool_result'); + t.deepEqual(events[0], { + type: 'tool_result', + call_id: 'call_1', + name: 'doc_read', + arguments: {}, + arguments_text: '{"doc_id":', + arguments_error: + events[0]?.type === 'tool_result' ? events[0].arguments_error : undefined, + output: { + message: 'Invalid tool arguments JSON', + rawArguments: '{"doc_id":', + error: + events[0]?.type === 'tool_result' + ? 
events[0].arguments_error + : undefined, + }, + is_error: true, + }); +}); diff --git a/packages/backend/server/src/__tests__/copilot/utils.spec.ts b/packages/backend/server/src/__tests__/copilot/utils.spec.ts new file mode 100644 index 0000000000..94844cc176 --- /dev/null +++ b/packages/backend/server/src/__tests__/copilot/utils.spec.ts @@ -0,0 +1,46 @@ +import test from 'ava'; + +import { CitationFootnoteFormatter } from '../../plugins/copilot/providers/utils'; + +test('CitationFootnoteFormatter should format sorted footnotes from citation events', t => { + const formatter = new CitationFootnoteFormatter(); + + formatter.consume({ + type: 'citation', + index: 2, + url: 'https://example.com/b', + }); + formatter.consume({ + type: 'citation', + index: 1, + url: 'https://example.com/a', + }); + + t.is( + formatter.end(), + [ + '[^1]: {"type":"url","url":"https%3A%2F%2Fexample.com%2Fa"}', + '[^2]: {"type":"url","url":"https%3A%2F%2Fexample.com%2Fb"}', + ].join('\n') + ); +}); + +test('CitationFootnoteFormatter should overwrite duplicated index with latest url', t => { + const formatter = new CitationFootnoteFormatter(); + + formatter.consume({ + type: 'citation', + index: 1, + url: 'https://example.com/old', + }); + formatter.consume({ + type: 'citation', + index: 1, + url: 'https://example.com/new', + }); + + t.is( + formatter.end(), + '[^1]: {"type":"url","url":"https%3A%2F%2Fexample.com%2Fnew"}' + ); +}); diff --git a/packages/backend/server/src/__tests__/e2e/doc-service/__snapshots__/controller.spec.ts.md b/packages/backend/server/src/__tests__/e2e/doc-service/__snapshots__/controller.spec.ts.md index f05073bd1a..e5b8f7fee8 100644 --- a/packages/backend/server/src/__tests__/e2e/doc-service/__snapshots__/controller.spec.ts.md +++ b/packages/backend/server/src/__tests__/e2e/doc-service/__snapshots__/controller.spec.ts.md @@ -9,6 +9,16 @@ Generated by [AVA](https://avajs.dev). 
> Snapshot 1 { + knownUnsupportedBlocks: [ + 'RX4CG2zsBk:affine:note', + 'S1mkc8zUoU:affine:note', + 'yGlBdshAqN:affine:note', + '6lDiuDqZGL:affine:note', + 'cauvaHOQmh:affine:note', + '2jwCeO8Yot:affine:note', + 'c9MF_JiRgx:affine:note', + '6x7ALjUDjj:affine:surface', + ], markdown: `AFFiNE is an open source all in one workspace, an operating system for all the building blocks of your team wiki, knowledge management and digital assets and a better alternative to Notion and Miro.␊ ␊ ␊ @@ -70,35 +80,9 @@ Generated by [AVA](https://avajs.dev). ␊ ␊ ␊ - ␊ - [](Bookmark,https://affine.pro/)␊ - ␊ - ␊ - [](Bookmark,https://www.youtube.com/@affinepro)␊ - ␊ - ␊ - ␊ - ␊ - ␊ - ␊ - ␊ `, title: 'Write, Draw, Plan all at Once.', + unknownBlocks: [], } ## should get doc markdown return null when doc not exists diff --git a/packages/backend/server/src/__tests__/e2e/doc-service/__snapshots__/controller.spec.ts.snap b/packages/backend/server/src/__tests__/e2e/doc-service/__snapshots__/controller.spec.ts.snap index 7bd07afacd..702fef10e3 100644 Binary files a/packages/backend/server/src/__tests__/e2e/doc-service/__snapshots__/controller.spec.ts.snap and b/packages/backend/server/src/__tests__/e2e/doc-service/__snapshots__/controller.spec.ts.snap differ diff --git a/packages/backend/server/src/__tests__/mails.spec.ts b/packages/backend/server/src/__tests__/mails.spec.ts index 7ed3bb8893..c0ee374769 100644 --- a/packages/backend/server/src/__tests__/mails.spec.ts +++ b/packages/backend/server/src/__tests__/mails.spec.ts @@ -1,5 +1,6 @@ import test from 'ava'; +import { normalizeSMTPHeloHostname } from '../core/mail/utils'; import { Renderers } from '../mails'; import { TEST_DOC, TEST_USER } from '../mails/common'; @@ -21,3 +22,22 @@ test('should render mention email with empty doc title', async t => { }); t.snapshot(content.html, content.subject); }); + +test('should normalize valid SMTP HELO hostnames', t => { + t.is(normalizeSMTPHeloHostname('mail.example.com'), 'mail.example.com'); 
+ t.is(normalizeSMTPHeloHostname(' localhost '), 'localhost'); + t.is(normalizeSMTPHeloHostname('[127.0.0.1]'), '[127.0.0.1]'); + t.is(normalizeSMTPHeloHostname('[IPv6:2001:db8::1]'), '[IPv6:2001:db8::1]'); +}); + +test('should reject invalid SMTP HELO hostnames', t => { + t.is(normalizeSMTPHeloHostname(''), undefined); + t.is(normalizeSMTPHeloHostname(' '), undefined); + t.is(normalizeSMTPHeloHostname('AFFiNE Server'), undefined); + t.is(normalizeSMTPHeloHostname('-example.com'), undefined); + t.is(normalizeSMTPHeloHostname('example-.com'), undefined); + t.is(normalizeSMTPHeloHostname('example..com'), undefined); + t.is(normalizeSMTPHeloHostname('[bad host]'), undefined); + t.is(normalizeSMTPHeloHostname('[foo]'), undefined); + t.is(normalizeSMTPHeloHostname('[IPv6:foo]'), undefined); +}); diff --git a/packages/backend/server/src/__tests__/mocks/copilot.mock.ts b/packages/backend/server/src/__tests__/mocks/copilot.mock.ts index a092329df9..b43dae79a4 100644 --- a/packages/backend/server/src/__tests__/mocks/copilot.mock.ts +++ b/packages/backend/server/src/__tests__/mocks/copilot.mock.ts @@ -33,39 +33,12 @@ export class MockCopilotProvider extends OpenAIProvider { id: 'test-image', capabilities: [ { - input: [ModelInputType.Text], + input: [ModelInputType.Text, ModelInputType.Image], output: [ModelOutputType.Image], defaultForOutputType: true, }, ], }, - { - id: 'gpt-4o', - capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ModelOutputType.Text, ModelOutputType.Object], - }, - ], - }, - { - id: 'gpt-4o-2024-08-06', - capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ModelOutputType.Text, ModelOutputType.Object], - }, - ], - }, - { - id: 'gpt-4.1-2025-04-14', - capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ModelOutputType.Text, ModelOutputType.Object], - }, - ], - }, { id: 'gpt-5', capabilities: [ @@ -97,6 +70,19 @@ export class MockCopilotProvider extends 
OpenAIProvider { }, ], }, + { + id: 'gpt-5-nano', + capabilities: [ + { + input: [ModelInputType.Text, ModelInputType.Image], + output: [ + ModelOutputType.Text, + ModelOutputType.Object, + ModelOutputType.Structured, + ], + }, + ], + }, { id: 'gpt-image-1', capabilities: [ @@ -133,6 +119,23 @@ export class MockCopilotProvider extends OpenAIProvider { }, ], }, + { + id: 'gemini-3.1-pro-preview', + capabilities: [ + { + input: [ + ModelInputType.Text, + ModelInputType.Image, + ModelInputType.Audio, + ], + output: [ + ModelOutputType.Text, + ModelOutputType.Object, + ModelOutputType.Structured, + ], + }, + ], + }, ]; override async text( diff --git a/packages/backend/server/src/__tests__/mocks/eventbus.mock.ts b/packages/backend/server/src/__tests__/mocks/eventbus.mock.ts index c9f887b22f..9a09b59fc5 100644 --- a/packages/backend/server/src/__tests__/mocks/eventbus.mock.ts +++ b/packages/backend/server/src/__tests__/mocks/eventbus.mock.ts @@ -8,6 +8,7 @@ export class MockEventBus { emit = this.stub.emitAsync; emitAsync = this.stub.emitAsync; + emitDetached = this.stub.emitAsync; broadcast = this.stub.broadcast; last( diff --git a/packages/backend/server/src/__tests__/native.spec.ts b/packages/backend/server/src/__tests__/native.spec.ts new file mode 100644 index 0000000000..bc088cfaa4 --- /dev/null +++ b/packages/backend/server/src/__tests__/native.spec.ts @@ -0,0 +1,82 @@ +import test from 'ava'; + +import { NativeStreamAdapter } from '../native'; + +test('NativeStreamAdapter should support buffered and awaited consumption', async t => { + const adapter = new NativeStreamAdapter(undefined); + + adapter.push(1); + const first = await adapter.next(); + t.deepEqual(first, { value: 1, done: false }); + + const pending = adapter.next(); + adapter.push(2); + const second = await pending; + t.deepEqual(second, { value: 2, done: false }); + + adapter.push(null); + const done = await adapter.next(); + t.true(done.done); +}); + +test('NativeStreamAdapter return should abort 
handle and end iteration', async t => { + let abortCount = 0; + const adapter = new NativeStreamAdapter({ + abort: () => { + abortCount += 1; + }, + }); + + const ended = await adapter.return(); + t.is(abortCount, 1); + t.true(ended.done); + + const secondReturn = await adapter.return(); + t.true(secondReturn.done); + t.is(abortCount, 1); + + const next = await adapter.next(); + t.true(next.done); +}); + +test('NativeStreamAdapter should abort when AbortSignal is triggered', async t => { + let abortCount = 0; + const controller = new AbortController(); + const adapter = new NativeStreamAdapter( + { + abort: () => { + abortCount += 1; + }, + }, + controller.signal + ); + + const pending = adapter.next(); + controller.abort(); + const done = await pending; + t.true(done.done); + t.is(abortCount, 1); +}); + +test('NativeStreamAdapter should end immediately for pre-aborted signal', async t => { + let abortCount = 0; + const controller = new AbortController(); + controller.abort(); + + const adapter = new NativeStreamAdapter( + { + abort: () => { + abortCount += 1; + }, + }, + controller.signal + ); + + const next = await adapter.next(); + t.true(next.done); + t.is(abortCount, 1); + + adapter.push(1); + const stillDone = await adapter.next(); + t.true(stillDone.done); +}); diff --git a/packages/backend/server/src/__tests__/user/user.e2e.ts b/packages/backend/server/src/__tests__/user/user.e2e.ts index c6490d3d53..8e31c9840a 100644 --- a/packages/backend/server/src/__tests__/user/user.e2e.ts +++ b/packages/backend/server/src/__tests__/user/user.e2e.ts @@ -4,9 +4,9 @@ import type { TestFn } from 'ava'; import ava from 'ava'; import { + createBmp, createTestingApp, getPublicUserById, - smallestGif, smallestPng, TestingApp, updateAvatar, @@ -40,7 +40,10 @@ test('should be able to upload user avatar', async t => { const avatarRes = await app.GET(new URL(avatarUrl).pathname); - t.deepEqual(avatarRes.body, avatar); + 
t.true(avatarRes.headers['content-type'].startsWith('image/webp')); + t.notDeepEqual(avatarRes.body, avatar); + t.is(avatarRes.body.subarray(0, 4).toString('ascii'), 'RIFF'); + t.is(avatarRes.body.subarray(8, 12).toString('ascii'), 'WEBP'); }); test('should be able to update user avatar, and invalidate old avatar url', async t => { @@ -54,9 +57,7 @@ test('should be able to update user avatar, and invalidate old avatar url', asyn const oldAvatarUrl = res.body.data.uploadAvatar.avatarUrl; - const newAvatar = await fetch(smallestGif) - .then(res => res.arrayBuffer()) - .then(b => Buffer.from(b)); + const newAvatar = createBmp(32, 32); res = await updateAvatar(app, newAvatar); const newAvatarUrl = res.body.data.uploadAvatar.avatarUrl; @@ -66,7 +67,46 @@ test('should be able to update user avatar, and invalidate old avatar url', asyn t.is(avatarRes.status, 404); const newAvatarRes = await app.GET(new URL(newAvatarUrl).pathname); - t.deepEqual(newAvatarRes.body, newAvatar); + t.true(newAvatarRes.headers['content-type'].startsWith('image/webp')); + t.notDeepEqual(newAvatarRes.body, newAvatar); + t.is(newAvatarRes.body.subarray(0, 4).toString('ascii'), 'RIFF'); + t.is(newAvatarRes.body.subarray(8, 12).toString('ascii'), 'WEBP'); +}); + +test('should accept avatar uploads up to 5MB after conversion', async t => { + const { app } = t.context; + + await app.signup(); + const avatar = createBmp(1024, 1024); + t.true(avatar.length > 500 * 1024); + t.true(avatar.length < 5 * 1024 * 1024); + + const res = await updateAvatar(app, avatar, { + filename: 'large.bmp', + contentType: 'image/bmp', + }); + + t.is(res.status, 200); + const avatarUrl = res.body.data.uploadAvatar.avatarUrl; + const avatarRes = await app.GET(new URL(avatarUrl).pathname); + + t.true(avatarRes.headers['content-type'].startsWith('image/webp')); +}); + +test('should reject unsupported vector avatars', async t => { + const { app } = t.context; + + await app.signup(); + const avatar = Buffer.from( + '' + ); + 
const res = await updateAvatar(app, avatar, { + filename: 'avatar.svg', + contentType: 'image/svg+xml', + }); + + t.is(res.status, 200); + t.is(res.body.errors[0].message, 'Image format not supported: image/svg+xml'); }); test('should be able to get public user by id', async t => { diff --git a/packages/backend/server/src/__tests__/utils/blobs.ts b/packages/backend/server/src/__tests__/utils/blobs.ts index d4e2f71029..ccca87d614 100644 --- a/packages/backend/server/src/__tests__/utils/blobs.ts +++ b/packages/backend/server/src/__tests__/utils/blobs.ts @@ -7,6 +7,35 @@ export const smallestPng = 'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAIAQMAAAD+wSzIAAAABlBMVEX///+/v7+jQ3Y5AAAADklEQVQI12P4AIX8EAgALgAD/aNpbtEAAAAASUVORK5CYII'; export const smallestGif = 'data:image/gif;base64,R0lGODlhAQABAAAAACw='; +export function createBmp(width: number, height: number) { + const rowSize = Math.ceil((width * 3) / 4) * 4; + const pixelDataSize = rowSize * height; + const fileSize = 54 + pixelDataSize; + const buffer = Buffer.alloc(fileSize); + + buffer.write('BM', 0, 'ascii'); + buffer.writeUInt32LE(fileSize, 2); + buffer.writeUInt32LE(54, 10); + buffer.writeUInt32LE(40, 14); + buffer.writeInt32LE(width, 18); + buffer.writeInt32LE(height, 22); + buffer.writeUInt16LE(1, 26); + buffer.writeUInt16LE(24, 28); + buffer.writeUInt32LE(pixelDataSize, 34); + + for (let y = 0; y < height; y++) { + const rowOffset = 54 + y * rowSize; + for (let x = 0; x < width; x++) { + const pixelOffset = rowOffset + x * 3; + buffer[pixelOffset] = 0x33; + buffer[pixelOffset + 1] = 0x66; + buffer[pixelOffset + 2] = 0x99; + } + } + + return buffer; +} + export async function listBlobs( app: TestingApp, workspaceId: string diff --git a/packages/backend/server/src/__tests__/utils/copilot.ts b/packages/backend/server/src/__tests__/utils/copilot.ts index 9e738ba43d..59aacb7b9e 100644 --- a/packages/backend/server/src/__tests__/utils/copilot.ts +++ 
b/packages/backend/server/src/__tests__/utils/copilot.ts @@ -629,14 +629,35 @@ export async function chatWithText( prefix = '', retry?: boolean ): Promise { + const endpoint = prefix || '/stream'; const query = messageId ? `?messageId=${messageId}` + (retry ? '&retry=true' : '') : ''; const res = await app - .GET(`/api/copilot/chat/${sessionId}${prefix}${query}`) + .GET(`/api/copilot/chat/${sessionId}${endpoint}${query}`) .expect(200); - return res.text; + if (prefix) { + return res.text; + } + + const events = sse2array(res.text); + const errorEvent = events.find(event => event.event === 'error'); + if (errorEvent?.data) { + let message = errorEvent.data; + try { + const parsed = JSON.parse(errorEvent.data); + message = parsed.message || message; + } catch { + // noop: keep raw error data + } + throw new Error(message); + } + + return events + .filter(event => event.event === 'message') + .map(event => event.data ?? '') + .join(''); } export async function chatWithTextStream( diff --git a/packages/backend/server/src/__tests__/utils/user.ts b/packages/backend/server/src/__tests__/utils/user.ts index 28811584bc..bb06d1dc04 100644 --- a/packages/backend/server/src/__tests__/utils/user.ts +++ b/packages/backend/server/src/__tests__/utils/user.ts @@ -121,7 +121,11 @@ export async function deleteAccount(app: TestingApp) { return res.deleteAccount.success; } -export async function updateAvatar(app: TestingApp, avatar: Buffer) { +export async function updateAvatar( + app: TestingApp, + avatar: Buffer, + options: { filename?: string; contentType?: string } = {} +) { return app .POST('/graphql') .field( @@ -138,7 +142,7 @@ export async function updateAvatar(app: TestingApp, avatar: Buffer) { ) .field('map', JSON.stringify({ '0': ['variables.avatar'] })) .attach('0', avatar, { - filename: 'test.png', - contentType: 'image/png', + filename: options.filename || 'test.png', + contentType: options.contentType || 'image/png', }); } diff --git 
a/packages/backend/server/src/__tests__/worker.e2e.ts b/packages/backend/server/src/__tests__/worker.e2e.ts index 1ebe09006e..d111a03d09 100644 --- a/packages/backend/server/src/__tests__/worker.e2e.ts +++ b/packages/backend/server/src/__tests__/worker.e2e.ts @@ -38,8 +38,11 @@ test.before(async t => { t.context.app = app; }); -test.after.always(async t => { +test.afterEach.always(() => { Sinon.restore(); +}); + +test.after.always(async t => { __resetDnsLookupForTests(); await t.context.app.close(); }); @@ -80,6 +83,7 @@ const assertAndSnapshotRaw = async ( test('should proxy image', async t => { const assertAndSnapshot = assertAndSnapshotRaw.bind(null, t); + const imageUrl = `http://example.com/image-${Date.now()}.png`; await assertAndSnapshot( '/api/worker/image-proxy', @@ -105,7 +109,7 @@ test('should proxy image', async t => { { await assertAndSnapshot( - '/api/worker/image-proxy?url=http://example.com/image.png', + `/api/worker/image-proxy?url=${imageUrl}`, 'should return 400 if origin and referer are missing', { status: 400, origin: null, referer: null } ); @@ -113,14 +117,17 @@ test('should proxy image', async t => { { await assertAndSnapshot( - '/api/worker/image-proxy?url=http://example.com/image.png', + `/api/worker/image-proxy?url=${imageUrl}`, 'should return 400 for invalid origin header', { status: 400, origin: 'http://invalid.com' } ); } { - const fakeBuffer = Buffer.from('fake image'); + const fakeBuffer = Buffer.from( + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+jfJ8AAAAASUVORK5CYII=', + 'base64' + ); const fakeResponse = new Response(fakeBuffer, { status: 200, headers: { @@ -130,13 +137,14 @@ test('should proxy image', async t => { }); const fetchSpy = Sinon.stub(global, 'fetch').resolves(fakeResponse); - - await assertAndSnapshot( - '/api/worker/image-proxy?url=http://example.com/image.png', - 'should return image buffer' - ); - - fetchSpy.restore(); + try { + await assertAndSnapshot( + 
`/api/worker/image-proxy?url=${imageUrl}`, + 'should return image buffer' + ); + } finally { + fetchSpy.restore(); + } } }); @@ -200,18 +208,19 @@ test('should preview link', async t => { }); const fetchSpy = Sinon.stub(global, 'fetch').resolves(fakeHTML); - - await assertAndSnapshot( - '/api/worker/link-preview', - 'should process a valid external URL and return link preview data', - { - status: 200, - method: 'POST', - body: { url: 'http://external.com/page' }, - } - ); - - fetchSpy.restore(); + try { + await assertAndSnapshot( + '/api/worker/link-preview', + 'should process a valid external URL and return link preview data', + { + status: 200, + method: 'POST', + body: { url: 'http://external.com/page' }, + } + ); + } finally { + fetchSpy.restore(); + } } { @@ -251,18 +260,19 @@ test('should preview link', async t => { }); const fetchSpy = Sinon.stub(global, 'fetch').resolves(fakeHTML); - - await assertAndSnapshot( - '/api/worker/link-preview', - 'should decode HTML content with charset', - { - status: 200, - method: 'POST', - body: { url: `http://example.com/${charset}` }, - } - ); - - fetchSpy.restore(); + try { + await assertAndSnapshot( + '/api/worker/link-preview', + 'should decode HTML content with charset', + { + status: 200, + method: 'POST', + body: { url: `http://example.com/${charset}` }, + } + ); + } finally { + fetchSpy.restore(); + } } } }); diff --git a/packages/backend/server/src/base/error/def.ts b/packages/backend/server/src/base/error/def.ts index 768eaa7884..be7682d564 100644 --- a/packages/backend/server/src/base/error/def.ts +++ b/packages/backend/server/src/base/error/def.ts @@ -301,6 +301,11 @@ export const USER_FRIENDLY_ERRORS = { }, // Input errors + image_format_not_supported: { + type: 'invalid_input', + args: { format: 'string' }, + message: ({ format }) => `Image format not supported: ${format}`, + }, query_too_long: { type: 'invalid_input', args: { max: 'number' }, diff --git a/packages/backend/server/src/base/error/errors.gen.ts 
b/packages/backend/server/src/base/error/errors.gen.ts index afbcb63df4..497b92c4ec 100644 --- a/packages/backend/server/src/base/error/errors.gen.ts +++ b/packages/backend/server/src/base/error/errors.gen.ts @@ -82,6 +82,16 @@ export class EmailServiceNotConfigured extends UserFriendlyError { } } @ObjectType() +class ImageFormatNotSupportedDataType { + @Field() format!: string +} + +export class ImageFormatNotSupported extends UserFriendlyError { + constructor(args: ImageFormatNotSupportedDataType, message?: string | ((args: ImageFormatNotSupportedDataType) => string)) { + super('invalid_input', 'image_format_not_supported', message, args); + } +} +@ObjectType() class QueryTooLongDataType { @Field() max!: number } @@ -1155,6 +1165,7 @@ export enum ErrorNames { SSRF_BLOCKED_ERROR, RESPONSE_TOO_LARGE_ERROR, EMAIL_SERVICE_NOT_CONFIGURED, + IMAGE_FORMAT_NOT_SUPPORTED, QUERY_TOO_LONG, VALIDATION_ERROR, USER_NOT_FOUND, @@ -1297,5 +1308,5 @@ registerEnumType(ErrorNames, { export const ErrorDataUnionType = createUnionType({ name: 'ErrorDataUnion', types: () => - [GraphqlBadRequestDataType, HttpRequestErrorDataType, SsrfBlockedErrorDataType, ResponseTooLargeErrorDataType, QueryTooLongDataType, ValidationErrorDataType, WrongSignInCredentialsDataType, UnknownOauthProviderDataType, InvalidOauthCallbackCodeDataType, MissingOauthQueryParameterDataType, InvalidOauthResponseDataType, InvalidEmailDataType, InvalidPasswordLengthDataType, WorkspacePermissionNotFoundDataType, SpaceNotFoundDataType, MemberNotFoundInSpaceDataType, NotInSpaceDataType, AlreadyInSpaceDataType, SpaceAccessDeniedDataType, SpaceOwnerNotFoundDataType, SpaceShouldHaveOnlyOneOwnerDataType, DocNotFoundDataType, DocActionDeniedDataType, DocUpdateBlockedDataType, VersionRejectedDataType, InvalidHistoryTimestampDataType, DocHistoryNotFoundDataType, BlobNotFoundDataType, ExpectToGrantDocUserRolesDataType, ExpectToRevokeDocUserRolesDataType, ExpectToUpdateDocUserRoleDataType, NoMoreSeatDataType, 
UnsupportedSubscriptionPlanDataType, SubscriptionAlreadyExistsDataType, SubscriptionNotExistsDataType, SameSubscriptionRecurringDataType, SubscriptionPlanNotFoundDataType, CalendarProviderRequestErrorDataType, NoCopilotProviderAvailableDataType, CopilotFailedToGenerateEmbeddingDataType, CopilotDocNotFoundDataType, CopilotMessageNotFoundDataType, CopilotPromptNotFoundDataType, CopilotProviderNotSupportedDataType, CopilotProviderSideErrorDataType, CopilotInvalidContextDataType, CopilotContextFileNotSupportedDataType, CopilotFailedToModifyContextDataType, CopilotFailedToMatchContextDataType, CopilotFailedToMatchGlobalContextDataType, CopilotFailedToAddWorkspaceFileEmbeddingDataType, RuntimeConfigNotFoundDataType, InvalidRuntimeConfigTypeDataType, InvalidLicenseToActivateDataType, InvalidLicenseUpdateParamsDataType, UnsupportedClientVersionDataType, MentionUserDocAccessDeniedDataType, InvalidAppConfigDataType, InvalidAppConfigInputDataType, InvalidSearchProviderRequestDataType, InvalidIndexerInputDataType] as const, + [GraphqlBadRequestDataType, HttpRequestErrorDataType, SsrfBlockedErrorDataType, ResponseTooLargeErrorDataType, ImageFormatNotSupportedDataType, QueryTooLongDataType, ValidationErrorDataType, WrongSignInCredentialsDataType, UnknownOauthProviderDataType, InvalidOauthCallbackCodeDataType, MissingOauthQueryParameterDataType, InvalidOauthResponseDataType, InvalidEmailDataType, InvalidPasswordLengthDataType, WorkspacePermissionNotFoundDataType, SpaceNotFoundDataType, MemberNotFoundInSpaceDataType, NotInSpaceDataType, AlreadyInSpaceDataType, SpaceAccessDeniedDataType, SpaceOwnerNotFoundDataType, SpaceShouldHaveOnlyOneOwnerDataType, DocNotFoundDataType, DocActionDeniedDataType, DocUpdateBlockedDataType, VersionRejectedDataType, InvalidHistoryTimestampDataType, DocHistoryNotFoundDataType, BlobNotFoundDataType, ExpectToGrantDocUserRolesDataType, ExpectToRevokeDocUserRolesDataType, ExpectToUpdateDocUserRoleDataType, NoMoreSeatDataType, 
UnsupportedSubscriptionPlanDataType, SubscriptionAlreadyExistsDataType, SubscriptionNotExistsDataType, SameSubscriptionRecurringDataType, SubscriptionPlanNotFoundDataType, CalendarProviderRequestErrorDataType, NoCopilotProviderAvailableDataType, CopilotFailedToGenerateEmbeddingDataType, CopilotDocNotFoundDataType, CopilotMessageNotFoundDataType, CopilotPromptNotFoundDataType, CopilotProviderNotSupportedDataType, CopilotProviderSideErrorDataType, CopilotInvalidContextDataType, CopilotContextFileNotSupportedDataType, CopilotFailedToModifyContextDataType, CopilotFailedToMatchContextDataType, CopilotFailedToMatchGlobalContextDataType, CopilotFailedToAddWorkspaceFileEmbeddingDataType, RuntimeConfigNotFoundDataType, InvalidRuntimeConfigTypeDataType, InvalidLicenseToActivateDataType, InvalidLicenseUpdateParamsDataType, UnsupportedClientVersionDataType, MentionUserDocAccessDeniedDataType, InvalidAppConfigDataType, InvalidAppConfigInputDataType, InvalidSearchProviderRequestDataType, InvalidIndexerInputDataType] as const, }); diff --git a/packages/backend/server/src/base/event/eventbus.ts b/packages/backend/server/src/base/event/eventbus.ts index 7d8c46ff61..cb90ba3ad5 100644 --- a/packages/backend/server/src/base/event/eventbus.ts +++ b/packages/backend/server/src/base/event/eventbus.ts @@ -88,12 +88,21 @@ export class EventBus emit(event: T, payload: Events[T]) { this.logger.debug(`Dispatch event: ${event}`); - // NOTE(@forehalo): - // Because all event handlers are wrapped in promisified metrics and cls context, they will always run in standalone tick. - // In which way, if handler throws, an unhandled rejection will be triggered and end up with process exiting. - // So we catch it here with `emitAsync` - this.emitter.emitAsync(event, payload).catch(e => { - this.emitter.emit('error', { event, payload, error: e }); + this.dispatchAsync(event, payload); + + return true; + } + + /** + * Emit event in detached cls context to avoid inheriting current transaction. 
+ */ + emitDetached(event: T, payload: Events[T]) { + this.logger.debug(`Dispatch event: ${event} (detached)`); + + const requestId = this.cls.getId(); + this.cls.run({ ifNested: 'override' }, () => { + this.cls.set(CLS_ID, requestId ?? genRequestId('event')); + this.dispatchAsync(event, payload); }); return true; @@ -166,6 +175,16 @@ export class EventBus return this.emitter.waitFor(name, timeout); } + private dispatchAsync(event: T, payload: Events[T]) { + // NOTE: + // Because all event handlers are wrapped in promisified metrics and cls context, they will always run in standalone tick. + // In which way, if handler throws, an unhandled rejection will be triggered and end up with process exiting. + // So we catch it here with `emitAsync` + this.emitter.emitAsync(event, payload).catch(e => { + this.emitter.emit('error', { event, payload, error: e }); + }); + } + private readonly bindEventHandlers = once(() => { this.scanner.scan().forEach(({ event, handler, opts }) => { this.on(event, handler, opts); diff --git a/packages/backend/server/src/core/doc-renderer/__tests__/controller.spec.ts b/packages/backend/server/src/core/doc-renderer/__tests__/controller.spec.ts index 56d75b112f..b76fd06894 100644 --- a/packages/backend/server/src/core/doc-renderer/__tests__/controller.spec.ts +++ b/packages/backend/server/src/core/doc-renderer/__tests__/controller.spec.ts @@ -129,6 +129,8 @@ test('should return markdown content and skip page view when accept is text/mark const markdown = Sinon.stub(docReader, 'getDocMarkdown').resolves({ title: 'markdown-doc', markdown: '# markdown-doc', + knownUnsupportedBlocks: [], + unknownBlocks: [], }); const docContent = Sinon.stub(docReader, 'getDocContent'); const record = Sinon.stub( diff --git a/packages/backend/server/src/core/doc-service/__tests__/controller.spec.ts b/packages/backend/server/src/core/doc-service/__tests__/controller.spec.ts index b03b6e10bb..c8163384d8 100644 --- 
a/packages/backend/server/src/core/doc-service/__tests__/controller.spec.ts +++ b/packages/backend/server/src/core/doc-service/__tests__/controller.spec.ts @@ -402,6 +402,8 @@ test('should get doc markdown in json format', async t => { return { title: 'test title', markdown: 'test markdown', + knownUnsupportedBlocks: [], + unknownBlocks: [], }; }); @@ -418,6 +420,8 @@ test('should get doc markdown in json format', async t => { .expect({ title: 'test title', markdown: 'test markdown', + knownUnsupportedBlocks: [], + unknownBlocks: [], }); t.pass(); }); diff --git a/packages/backend/server/src/core/doc/__tests__/__snapshots__/reader-from-database.spec.ts.md b/packages/backend/server/src/core/doc/__tests__/__snapshots__/reader-from-database.spec.ts.md index a2672336e2..c0420af59c 100644 --- a/packages/backend/server/src/core/doc/__tests__/__snapshots__/reader-from-database.spec.ts.md +++ b/packages/backend/server/src/core/doc/__tests__/__snapshots__/reader-from-database.spec.ts.md @@ -9,6 +9,16 @@ Generated by [AVA](https://avajs.dev). > Snapshot 1 { + knownUnsupportedBlocks: [ + 'RX4CG2zsBk:affine:note', + 'S1mkc8zUoU:affine:note', + 'yGlBdshAqN:affine:note', + '6lDiuDqZGL:affine:note', + 'cauvaHOQmh:affine:note', + '2jwCeO8Yot:affine:note', + 'c9MF_JiRgx:affine:note', + '6x7ALjUDjj:affine:surface', + ], markdown: `AFFiNE is an open source all in one workspace, an operating system for all the building blocks of your team wiki, knowledge management and digital assets and a better alternative to Notion and Miro.␊ ␊ ␊ @@ -70,33 +80,7 @@ Generated by [AVA](https://avajs.dev). 
␊ ␊ ␊ - ␊ - [](Bookmark,https://affine.pro/)␊ - ␊ - ␊ - [](Bookmark,https://www.youtube.com/@affinepro)␊ - ␊ - ␊ - ␊ - ␊ - ␊ - ␊ - ␊ `, title: 'Write, Draw, Plan all at Once.', + unknownBlocks: [], } diff --git a/packages/backend/server/src/core/doc/__tests__/__snapshots__/reader-from-database.spec.ts.snap b/packages/backend/server/src/core/doc/__tests__/__snapshots__/reader-from-database.spec.ts.snap index c2e0c3a45c..dc0fdb3923 100644 Binary files a/packages/backend/server/src/core/doc/__tests__/__snapshots__/reader-from-database.spec.ts.snap and b/packages/backend/server/src/core/doc/__tests__/__snapshots__/reader-from-database.spec.ts.snap differ diff --git a/packages/backend/server/src/core/doc/__tests__/__snapshots__/reader-from-rpc.spec.ts.md b/packages/backend/server/src/core/doc/__tests__/__snapshots__/reader-from-rpc.spec.ts.md index 32460b7781..898ac18e77 100644 --- a/packages/backend/server/src/core/doc/__tests__/__snapshots__/reader-from-rpc.spec.ts.md +++ b/packages/backend/server/src/core/doc/__tests__/__snapshots__/reader-from-rpc.spec.ts.md @@ -9,6 +9,16 @@ Generated by [AVA](https://avajs.dev). > Snapshot 1 { + knownUnsupportedBlocks: [ + 'RX4CG2zsBk:affine:note', + 'S1mkc8zUoU:affine:note', + 'yGlBdshAqN:affine:note', + '6lDiuDqZGL:affine:note', + 'cauvaHOQmh:affine:note', + '2jwCeO8Yot:affine:note', + 'c9MF_JiRgx:affine:note', + '6x7ALjUDjj:affine:surface', + ], markdown: `AFFiNE is an open source all in one workspace, an operating system for all the building blocks of your team wiki, knowledge management and digital assets and a better alternative to Notion and Miro.␊ ␊ ␊ @@ -70,33 +80,7 @@ Generated by [AVA](https://avajs.dev). 
␊ ␊ ␊ - ␊ - [](Bookmark,https://affine.pro/)␊ - ␊ - ␊ - [](Bookmark,https://www.youtube.com/@affinepro)␊ - ␊ - ␊ - ␊ - ␊ - ␊ - ␊ - ␊ `, title: 'Write, Draw, Plan all at Once.', + unknownBlocks: [], } diff --git a/packages/backend/server/src/core/doc/__tests__/__snapshots__/reader-from-rpc.spec.ts.snap b/packages/backend/server/src/core/doc/__tests__/__snapshots__/reader-from-rpc.spec.ts.snap index c2e0c3a45c..dc0fdb3923 100644 Binary files a/packages/backend/server/src/core/doc/__tests__/__snapshots__/reader-from-rpc.spec.ts.snap and b/packages/backend/server/src/core/doc/__tests__/__snapshots__/reader-from-rpc.spec.ts.snap differ diff --git a/packages/backend/server/src/core/doc/__tests__/event.spec.ts b/packages/backend/server/src/core/doc/__tests__/event.spec.ts index d1483b90c1..eea4402b55 100644 --- a/packages/backend/server/src/core/doc/__tests__/event.spec.ts +++ b/packages/backend/server/src/core/doc/__tests__/event.spec.ts @@ -68,7 +68,7 @@ test('should update doc content to database when doc is updated', async t => { const docId = randomUUID(); await adapter.pushDocUpdates(workspace.id, docId, updates); - await adapter.getDoc(workspace.id, docId); + await adapter.getDocBinNative(workspace.id, docId); mock.method(docReader, 'parseDocContent', () => { return { @@ -181,3 +181,22 @@ test('should ignore update workspace content to database when parse workspace co t.is(content!.name, null); t.is(content!.avatarKey, null); }); + +test('should ignore stale workspace when updating doc meta from snapshot event', async t => { + const { docReader, listener, models } = t.context; + const docId = randomUUID(); + mock.method(docReader, 'parseDocContent', () => ({ + title: 'test title', + summary: 'test summary', + })); + + await models.workspace.delete(workspace.id); + + await t.notThrowsAsync(async () => { + await listener.markDocContentCacheStale({ + workspaceId: workspace.id, + docId, + blob: Buffer.from([0x01]), + }); + }); +}); diff --git 
a/packages/backend/server/src/core/doc/adapters/workspace.ts b/packages/backend/server/src/core/doc/adapters/workspace.ts index 6e8ce11af9..ae00fb5986 100644 --- a/packages/backend/server/src/core/doc/adapters/workspace.ts +++ b/packages/backend/server/src/core/doc/adapters/workspace.ts @@ -110,7 +110,7 @@ export class PgWorkspaceDocStorageAdapter extends DocStorageAdapter { }); if (isNewDoc) { - this.event.emit('doc.created', { + this.event.emitDetached('doc.created', { workspaceId, docId, editor: editorId, @@ -334,7 +334,7 @@ export class PgWorkspaceDocStorageAdapter extends DocStorageAdapter { }); if (updatedSnapshot) { - this.event.emit('doc.snapshot.updated', { + this.event.emitDetached('doc.snapshot.updated', { workspaceId: snapshot.spaceId, docId: snapshot.docId, blob, diff --git a/packages/backend/server/src/core/doc/event.ts b/packages/backend/server/src/core/doc/event.ts index a9982fa954..c3d42c068d 100644 --- a/packages/backend/server/src/core/doc/event.ts +++ b/packages/backend/server/src/core/doc/event.ts @@ -1,12 +1,29 @@ -import { Injectable } from '@nestjs/common'; +import { Injectable, Logger } from '@nestjs/common'; +import { Prisma } from '@prisma/client'; import { OnEvent } from '../../base'; import { Models } from '../../models'; import { PgWorkspaceDocStorageAdapter } from './adapters/workspace'; import { DocReader } from './reader'; +const IGNORED_PRISMA_CODES = new Set(['P2003', 'P2025', 'P2028']); + +function isIgnorableDocEventError(error: unknown) { + if (error instanceof Prisma.PrismaClientKnownRequestError) { + return IGNORED_PRISMA_CODES.has(error.code); + } + if (error instanceof Prisma.PrismaClientUnknownRequestError) { + return /transaction is aborted|transaction already closed/i.test( + error.message + ); + } + return false; +} + @Injectable() export class DocEventsListener { + private readonly logger = new Logger(DocEventsListener.name); + constructor( private readonly docReader: DocReader, private readonly models: Models, @@ 
-20,21 +37,39 @@ export class DocEventsListener { blob, }: Events['doc.snapshot.updated']) { await this.docReader.markDocContentCacheStale(workspaceId, docId); + const workspace = await this.models.workspace.get(workspaceId); + if (!workspace) { + this.logger.warn( + `Skip stale doc snapshot event for missing workspace ${workspaceId}/${docId}` + ); + return; + } const isDoc = workspaceId !== docId; // update doc content to database - if (isDoc) { - const content = this.docReader.parseDocContent(blob); - if (!content) { + try { + if (isDoc) { + const content = this.docReader.parseDocContent(blob); + if (!content) { + return; + } + await this.models.doc.upsertMeta(workspaceId, docId, content); + } else { + // update workspace content to database + const content = this.docReader.parseWorkspaceContent(blob); + if (!content) { + return; + } + await this.models.workspace.update(workspaceId, content); + } + } catch (error) { + if (isIgnorableDocEventError(error)) { + const message = error instanceof Error ? 
error.message : String(error); + this.logger.warn( + `Ignore stale doc snapshot event for ${workspaceId}/${docId}: ${message}` + ); return; } - await this.models.doc.upsertMeta(workspaceId, docId, content); - } else { - // update workspace content to database - const content = this.docReader.parseWorkspaceContent(blob); - if (!content) { - return; - } - await this.models.workspace.update(workspaceId, content); + throw error; } } diff --git a/packages/backend/server/src/core/doc/reader.ts b/packages/backend/server/src/core/doc/reader.ts index a8f03ed5f9..c671f33171 100644 --- a/packages/backend/server/src/core/doc/reader.ts +++ b/packages/backend/server/src/core/doc/reader.ts @@ -32,6 +32,8 @@ export interface WorkspaceDocInfo { export interface DocMarkdown { title: string; markdown: string; + knownUnsupportedBlocks: string[]; + unknownBlocks: string[]; } export abstract class DocReader { @@ -185,12 +187,27 @@ export class DatabaseDocReader extends DocReader { if (!doc) { return null; } - return parseDocToMarkdownFromDocSnapshot( - workspaceId, - docId, - doc.bin, - aiEditable - ); + try { + const markdown = parseDocToMarkdownFromDocSnapshot( + workspaceId, + docId, + doc.bin, + aiEditable + ); + + const unknownBlocks = markdown.unknownBlocks ?? []; + if (unknownBlocks.length > 0) { + this.logger.warn( + `Unknown blocks found when parsing markdown for ${workspaceId}/${docId}.`, + { unknownBlocks } + ); + } + + return markdown; + } catch (error) { + this.logger.error(`Failed to parse ${workspaceId}/${docId}.`, error); + throw error; + } } async getDocDiff( diff --git a/packages/backend/server/src/core/mail/config.ts b/packages/backend/server/src/core/mail/config.ts index 97c61f41a7..fdac7b4eb1 100644 --- a/packages/backend/server/src/core/mail/config.ts +++ b/packages/backend/server/src/core/mail/config.ts @@ -31,8 +31,8 @@ declare global { defineModuleConfig('mailer', { 'SMTP.name': { - desc: 'Name of the email server (e.g. 
your domain name)', - default: 'AFFiNE Server', + desc: 'Hostname used for SMTP HELO/EHLO (e.g. mail.example.com). Leave empty to use the system hostname.', + default: '', env: 'MAILER_SERVERNAME', }, 'SMTP.host': { @@ -72,8 +72,8 @@ defineModuleConfig('mailer', { shape: z.array(z.string()), }, 'fallbackSMTP.name': { - desc: 'Name of the fallback email server (e.g. your domain name)', - default: 'AFFiNE Server', + desc: 'Hostname used for fallback SMTP HELO/EHLO (e.g. mail.example.com). Leave empty to use the system hostname.', + default: '', }, 'fallbackSMTP.host': { desc: 'Host of the email server (e.g. smtp.gmail.com)', diff --git a/packages/backend/server/src/core/mail/sender.ts b/packages/backend/server/src/core/mail/sender.ts index 70bf39007e..aadcaa976d 100644 --- a/packages/backend/server/src/core/mail/sender.ts +++ b/packages/backend/server/src/core/mail/sender.ts @@ -9,6 +9,7 @@ import { import SMTPTransport from 'nodemailer/lib/smtp-transport'; import { Config, metrics, OnEvent } from '../../base'; +import { resolveSMTPHeloHostname } from './utils'; export type SendOptions = Omit & { to: string; @@ -19,8 +20,10 @@ export type SendOptions = Omit & { function configToSMTPOptions( config: AppConfig['mailer']['SMTP'] ): SMTPTransport.Options { + const name = resolveSMTPHeloHostname(config.name); + return { - name: config.name, + ...(name ? 
{ name } : {}), host: config.host, port: config.port, tls: { diff --git a/packages/backend/server/src/core/mail/utils.ts b/packages/backend/server/src/core/mail/utils.ts new file mode 100644 index 0000000000..33adf329a5 --- /dev/null +++ b/packages/backend/server/src/core/mail/utils.ts @@ -0,0 +1,53 @@ +import { isIP } from 'node:net'; +import { hostname as getHostname } from 'node:os'; + +const hostnameLabelRegexp = /^[A-Za-z0-9-]+$/; + +function isValidSMTPAddressLiteral(hostname: string) { + if (!hostname.startsWith('[') || !hostname.endsWith(']')) return false; + + const literal = hostname.slice(1, -1); + if (!literal || literal.includes(' ')) return false; + if (isIP(literal) === 4) return true; + + if (literal.startsWith('IPv6:')) { + return isIP(literal.slice('IPv6:'.length)) === 6; + } + + return false; +} + +export function normalizeSMTPHeloHostname(hostname: string) { + const normalized = hostname.trim().replace(/\.$/, ''); + if (!normalized) return undefined; + if (isValidSMTPAddressLiteral(normalized)) return normalized; + if (normalized.length > 253) return undefined; + + const labels = normalized.split('.'); + for (const label of labels) { + if (!label || label.length > 63) return undefined; + if ( + !hostnameLabelRegexp.test(label) || + label.startsWith('-') || + label.endsWith('-') + ) { + return undefined; + } + } + + return normalized; +} + +function readSystemHostname() { + try { + return getHostname(); + } catch { + return ''; + } +} + +export function resolveSMTPHeloHostname(configuredName: string) { + const normalizedConfiguredName = normalizeSMTPHeloHostname(configuredName); + if (normalizedConfiguredName) return normalizedConfiguredName; + return normalizeSMTPHeloHostname(readSystemHostname()); +} diff --git a/packages/backend/server/src/core/permission/__tests__/event.spec.ts b/packages/backend/server/src/core/permission/__tests__/event.spec.ts new file mode 100644 index 0000000000..21803c0532 --- /dev/null +++ 
b/packages/backend/server/src/core/permission/__tests__/event.spec.ts @@ -0,0 +1,77 @@ +import { randomUUID } from 'node:crypto'; + +import ava, { TestFn } from 'ava'; + +import { + createTestingModule, + type TestingModule, +} from '../../../__tests__/utils'; +import { DocRole, Models, User, Workspace } from '../../../models'; +import { EventsListener } from '../event'; +import { PermissionModule } from '../index'; + +interface Context { + module: TestingModule; + models: Models; + listener: EventsListener; +} + +const test = ava as TestFn; + +let owner: User; +let workspace: Workspace; + +test.before(async t => { + const module = await createTestingModule({ imports: [PermissionModule] }); + t.context.module = module; + t.context.models = module.get(Models); + t.context.listener = module.get(EventsListener); +}); + +test.beforeEach(async t => { + await t.context.module.initTestingDB(); + owner = await t.context.models.user.create({ + email: `${randomUUID()}@affine.pro`, + }); + workspace = await t.context.models.workspace.create(owner.id); +}); + +test.after.always(async t => { + await t.context.module.close(); +}); + +test('should ignore default owner event when workspace does not exist', async t => { + await t.notThrowsAsync(async () => { + await t.context.listener.setDefaultPageOwner({ + workspaceId: randomUUID(), + docId: randomUUID(), + editor: owner.id, + }); + }); +}); + +test('should ignore default owner event when editor does not exist', async t => { + await t.notThrowsAsync(async () => { + await t.context.listener.setDefaultPageOwner({ + workspaceId: workspace.id, + docId: randomUUID(), + editor: randomUUID(), + }); + }); +}); + +test('should set owner when workspace and editor exist', async t => { + const docId = randomUUID(); + await t.context.listener.setDefaultPageOwner({ + workspaceId: workspace.id, + docId, + editor: owner.id, + }); + + const role = await t.context.models.docUser.get( + workspace.id, + docId, + owner.id + ); + t.is(role?.type, 
DocRole.Owner); +}); diff --git a/packages/backend/server/src/core/permission/event.ts b/packages/backend/server/src/core/permission/event.ts index 4e6651a145..442fd6877a 100644 --- a/packages/backend/server/src/core/permission/event.ts +++ b/packages/backend/server/src/core/permission/event.ts @@ -1,10 +1,27 @@ -import { Injectable } from '@nestjs/common'; +import { Injectable, Logger } from '@nestjs/common'; +import { Prisma } from '@prisma/client'; import { OnEvent } from '../../base'; import { Models } from '../../models'; +const IGNORED_PRISMA_CODES = new Set(['P2003', 'P2025', 'P2028']); + +function isIgnorablePermissionEventError(error: unknown) { + if (error instanceof Prisma.PrismaClientKnownRequestError) { + return IGNORED_PRISMA_CODES.has(error.code); + } + if (error instanceof Prisma.PrismaClientUnknownRequestError) { + return /transaction is aborted|transaction already closed/i.test( + error.message + ); + } + return false; +} + @Injectable() export class EventsListener { + private readonly logger = new Logger(EventsListener.name); + constructor(private readonly models: Models) {} @OnEvent('doc.created') @@ -15,6 +32,33 @@ export class EventsListener { return; } - await this.models.docUser.setOwner(workspaceId, docId, editor); + const workspace = await this.models.workspace.get(workspaceId); + if (!workspace) { + this.logger.warn( + `Skip default doc owner event for missing workspace ${workspaceId}/${docId}` + ); + return; + } + + const user = await this.models.user.get(editor); + if (!user) { + this.logger.warn( + `Skip default doc owner event for missing editor ${workspaceId}/${docId}/${editor}` + ); + return; + } + + try { + await this.models.docUser.setOwner(workspaceId, docId, editor); + } catch (error) { + if (isIgnorablePermissionEventError(error)) { + const message = error instanceof Error ? 
error.message : String(error); + this.logger.warn( + `Ignore stale doc owner event for ${workspaceId}/${docId}/${editor}: ${message}` + ); + return; + } + throw error; + } } } diff --git a/packages/backend/server/src/core/telemetry/ga4-client.ts b/packages/backend/server/src/core/telemetry/ga4-client.ts index 6afa73fdb1..7543b29a8b 100644 --- a/packages/backend/server/src/core/telemetry/ga4-client.ts +++ b/packages/backend/server/src/core/telemetry/ga4-client.ts @@ -1,3 +1,5 @@ +import { Logger } from '@nestjs/common'; + import { CleanedTelemetryEvent, Scalar } from './cleaner'; const GA4_ENDPOINT = 'https://www.google-analytics.com/mp/collect'; @@ -14,6 +16,7 @@ type Ga4Payload = { }; export class Ga4Client { + private readonly logger = new Logger(Ga4Client.name); constructor( private readonly measurementId: string, private readonly apiSecret: string, @@ -42,8 +45,21 @@ export class Ga4Client { timestamp_micros: event.timestampMicros, })), }; - - await this.post(payload); + try { + await this.post(payload); + } catch { + if ( + env.DEPLOYMENT_TYPE === 'affine' && + env.NODE_ENV === 'production' + ) { + // GA4 delivery is best-effort: failures are always caught and swallowed so telemetry never breaks callers. + // In the hosted AFFiNE production deployment we additionally log which events failed to send. 
+ this.logger.log( + 'Failed to send telemetry event to GA4:', + chunk.map(e => e.eventName).join(', ') + ); + } + } } } } diff --git a/packages/backend/server/src/core/user/resolver.ts b/packages/backend/server/src/core/user/resolver.ts index 565be2b4c2..275db70f32 100644 --- a/packages/backend/server/src/core/user/resolver.ts +++ b/packages/backend/server/src/core/user/resolver.ts @@ -17,6 +17,8 @@ import { isNil, omitBy } from 'lodash-es'; import { CannotDeleteOwnAccount, type FileUpload, + ImageFormatNotSupported, + OneMB, readBufferWithLimit, sniffMime, Throttle, @@ -28,6 +30,7 @@ import { UserFeatureName, UserSettingsSchema, } from '../../models'; +import { processImage } from '../../native'; import { Public } from '../auth/guard'; import { sessionUser } from '../auth/service'; import { CurrentUser } from '../auth/session'; @@ -115,16 +118,26 @@ export class UserResolver { throw new UserNotFound(); } - const avatarBuffer = await readBufferWithLimit(avatar.createReadStream()); - const contentType = sniffMime(avatarBuffer, avatar.mimetype); + const avatarBuffer = await readBufferWithLimit( + avatar.createReadStream(), + 5 * OneMB + ); + const contentType = sniffMime(avatarBuffer, avatar.mimetype)?.toLowerCase(); if (!contentType || !contentType.startsWith('image/')) { - throw new Error(`Invalid file type: ${contentType || 'unknown'}`); + throw new ImageFormatNotSupported({ format: contentType || 'unknown' }); + } + + let processedAvatarBuffer: Buffer; + try { + processedAvatarBuffer = await processImage(avatarBuffer, 512, false); + } catch { + throw new ImageFormatNotSupported({ format: contentType }); } const avatarUrl = await this.storage.put( `${user.id}-avatar-${Date.now()}`, - avatarBuffer, - { contentType } + processedAvatarBuffer, + { contentType: 'image/webp' } ); if (user.avatarUrl) { diff --git a/packages/backend/server/src/core/utils/__tests__/__snapshots__/blocksute.spec.ts.md 
b/packages/backend/server/src/core/utils/__tests__/__snapshots__/blocksute.spec.ts.md index 998a1ed4e2..e1288ed3e7 100644 --- a/packages/backend/server/src/core/utils/__tests__/__snapshots__/blocksute.spec.ts.md +++ b/packages/backend/server/src/core/utils/__tests__/__snapshots__/blocksute.spec.ts.md @@ -1379,6 +1379,16 @@ Generated by [AVA](https://avajs.dev). > Snapshot 1 { + knownUnsupportedBlocks: [ + 'RX4CG2zsBk:affine:note', + 'S1mkc8zUoU:affine:note', + 'yGlBdshAqN:affine:note', + '6lDiuDqZGL:affine:note', + 'cauvaHOQmh:affine:note', + '2jwCeO8Yot:affine:note', + 'c9MF_JiRgx:affine:note', + '6x7ALjUDjj:affine:surface', + ], markdown: `AFFiNE is an open source all in one workspace, an operating system for all the building blocks of your team wiki, knowledge management and digital assets and a better alternative to Notion and Miro.␊ ␊ ␊ @@ -1440,35 +1450,9 @@ Generated by [AVA](https://avajs.dev). ␊ ␊ ␊ - ␊ - [](Bookmark,https://affine.pro/)␊ - ␊ - ␊ - [](Bookmark,https://www.youtube.com/@affinepro)␊ - ␊ - ␊ - ␊ - ␊ - ␊ - ␊ - ␊ `, title: 'Write, Draw, Plan all at Once.', + unknownBlocks: [], } ## can parse doc to markdown from doc snapshot with ai editable @@ -1476,6 +1460,16 @@ Generated by [AVA](https://avajs.dev). > Snapshot 1 { + knownUnsupportedBlocks: [ + 'RX4CG2zsBk:affine:note', + 'S1mkc8zUoU:affine:note', + 'yGlBdshAqN:affine:note', + '6lDiuDqZGL:affine:note', + 'cauvaHOQmh:affine:note', + '2jwCeO8Yot:affine:note', + 'c9MF_JiRgx:affine:note', + '6x7ALjUDjj:affine:surface', + ], markdown: `␊ AFFiNE is an open source all in one workspace, an operating system for all the building blocks of your team wiki, knowledge management and digital assets and a better alternative to Notion and Miro.␊ ␊ @@ -1565,38 +1559,7 @@ Generated by [AVA](https://avajs.dev). 
␊ ␊ ␊ - ␊ - ␊ - [](Bookmark,https://affine.pro/)␊ - ␊ - ␊ - ␊ - [](Bookmark,https://www.youtube.com/@affinepro)␊ - ␊ - ␊ - ␊ - ␊ - ␊ - ␊ - ␊ - ␊ - ␊ - ␊ `, title: 'Write, Draw, Plan all at Once.', + unknownBlocks: [], } diff --git a/packages/backend/server/src/core/utils/__tests__/__snapshots__/blocksute.spec.ts.snap b/packages/backend/server/src/core/utils/__tests__/__snapshots__/blocksute.spec.ts.snap index 9833dcdc06..cdac9bc904 100644 Binary files a/packages/backend/server/src/core/utils/__tests__/__snapshots__/blocksute.spec.ts.snap and b/packages/backend/server/src/core/utils/__tests__/__snapshots__/blocksute.spec.ts.snap differ diff --git a/packages/backend/server/src/core/utils/blocksuite.ts b/packages/backend/server/src/core/utils/blocksuite.ts index d6532b3df2..d0b66a3680 100644 --- a/packages/backend/server/src/core/utils/blocksuite.ts +++ b/packages/backend/server/src/core/utils/blocksuite.ts @@ -16,6 +16,13 @@ export interface WorkspaceDocContent { avatarKey: string; } +export interface DocMarkdownContent { + title: string; + markdown: string; + knownUnsupportedBlocks: string[]; + unknownBlocks: string[]; +} + export interface ParsePageOptions { maxSummaryLength?: number; } @@ -74,7 +81,7 @@ export function parseDocToMarkdownFromDocSnapshot( docId: string, docSnapshot: Uint8Array, aiEditable = false -) { +): DocMarkdownContent { const docUrlPrefix = workspaceId ? `/workspace/${workspaceId}` : undefined; const parsed = parseYDocToMarkdown( Buffer.from(docSnapshot), @@ -86,5 +93,7 @@ export function parseDocToMarkdownFromDocSnapshot( return { title: parsed.title, markdown: parsed.markdown, + knownUnsupportedBlocks: parsed.knownUnsupportedBlocks ?? [], + unknownBlocks: parsed.unknownBlocks ?? 
[], }; } diff --git a/packages/backend/server/src/models/copilot-session.ts b/packages/backend/server/src/models/copilot-session.ts index 0cc1ee9b4a..b5d177e27b 100644 --- a/packages/backend/server/src/models/copilot-session.ts +++ b/packages/backend/server/src/models/copilot-session.ts @@ -10,6 +10,7 @@ import { CopilotSessionNotFound, } from '../base'; import { getTokenEncoder } from '../native'; +import type { PromptAttachment } from '../plugins/copilot/providers/types'; import { BaseModel } from './base'; export enum SessionType { @@ -24,7 +25,7 @@ type ChatPrompt = { model: string; }; -type ChatAttachment = { attachment: string; mimeType: string } | string; +type ChatAttachment = PromptAttachment; type ChatStreamObject = { type: 'text-delta' | 'reasoning' | 'tool-call' | 'tool-result'; @@ -173,22 +174,105 @@ export class CopilotSessionModel extends BaseModel { } return attachments - .map(attachment => - typeof attachment === 'string' - ? (this.sanitizeString(attachment) ?? '') - : { - attachment: - this.sanitizeString(attachment.attachment) ?? - attachment.attachment, + .map(attachment => { + if (typeof attachment === 'string') { + return this.sanitizeString(attachment) ?? ''; + } + + if ('attachment' in attachment) { + return { + attachment: + this.sanitizeString(attachment.attachment) ?? + attachment.attachment, + mimeType: + this.sanitizeString(attachment.mimeType) ?? attachment.mimeType, + }; + } + + switch (attachment.kind) { + case 'url': + return { + ...attachment, + url: this.sanitizeString(attachment.url) ?? attachment.url, mimeType: this.sanitizeString(attachment.mimeType) ?? attachment.mimeType, - } - ) + fileName: + this.sanitizeString(attachment.fileName) ?? attachment.fileName, + providerHint: attachment.providerHint + ? { + provider: + this.sanitizeString(attachment.providerHint.provider) ?? + attachment.providerHint.provider, + kind: + this.sanitizeString(attachment.providerHint.kind) ?? 
+ attachment.providerHint.kind, + } + : undefined, + }; + case 'data': + case 'bytes': + return { + ...attachment, + data: this.sanitizeString(attachment.data) ?? attachment.data, + mimeType: + this.sanitizeString(attachment.mimeType) ?? attachment.mimeType, + fileName: + this.sanitizeString(attachment.fileName) ?? attachment.fileName, + providerHint: attachment.providerHint + ? { + provider: + this.sanitizeString(attachment.providerHint.provider) ?? + attachment.providerHint.provider, + kind: + this.sanitizeString(attachment.providerHint.kind) ?? + attachment.providerHint.kind, + } + : undefined, + }; + case 'file_handle': + return { + ...attachment, + fileHandle: + this.sanitizeString(attachment.fileHandle) ?? + attachment.fileHandle, + mimeType: + this.sanitizeString(attachment.mimeType) ?? attachment.mimeType, + fileName: + this.sanitizeString(attachment.fileName) ?? attachment.fileName, + providerHint: attachment.providerHint + ? { + provider: + this.sanitizeString(attachment.providerHint.provider) ?? + attachment.providerHint.provider, + kind: + this.sanitizeString(attachment.providerHint.kind) ?? 
+ attachment.providerHint.kind, + } + : undefined, + }; + } + + return attachment; + }) .filter(attachment => { if (typeof attachment === 'string') { return !!attachment; } - return !!attachment.attachment && !!attachment.mimeType; + if ('attachment' in attachment) { + return !!attachment.attachment && !!attachment.mimeType; + } + + switch (attachment.kind) { + case 'url': + return !!attachment.url; + case 'data': + case 'bytes': + return !!attachment.data && !!attachment.mimeType; + case 'file_handle': + return !!attachment.fileHandle; + } + + return false; }); } diff --git a/packages/backend/server/src/models/doc-user.ts b/packages/backend/server/src/models/doc-user.ts index d8a71e3bda..f189fd5d01 100644 --- a/packages/backend/server/src/models/doc-user.ts +++ b/packages/backend/server/src/models/doc-user.ts @@ -2,6 +2,7 @@ import assert from 'node:assert'; import { Injectable } from '@nestjs/common'; import { Transactional } from '@nestjs-cls/transactional'; +import type { TransactionalAdapterPrisma } from '@nestjs-cls/transactional-adapter-prisma'; import { WorkspaceDocUserRole } from '@prisma/client'; import { CanNotBatchGrantDocOwnerPermissions, PaginationInput } from '../base'; @@ -14,31 +15,20 @@ export class DocUserModel extends BaseModel { * Set or update the [Owner] of a doc. * The old [Owner] will be changed to [Manager] if there is already an [Owner]. 
*/ - @Transactional() + @Transactional({ timeout: 15000 }) async setOwner(workspaceId: string, docId: string, userId: string) { - const oldOwner = await this.db.workspaceDocUserRole.findFirst({ + await this.db.workspaceDocUserRole.updateMany({ where: { workspaceId, docId, type: DocRole.Owner, + userId: { not: userId }, + }, + data: { + type: DocRole.Manager, }, }); - if (oldOwner) { - await this.db.workspaceDocUserRole.update({ - where: { - workspaceId_docId_userId: { - workspaceId, - docId, - userId: oldOwner.userId, - }, - }, - data: { - type: DocRole.Manager, - }, - }); - } - await this.db.workspaceDocUserRole.upsert({ where: { workspaceId_docId_userId: { @@ -57,16 +47,9 @@ export class DocUserModel extends BaseModel { type: DocRole.Owner, }, }); - - if (oldOwner) { - this.logger.log( - `Transfer doc owner of [${workspaceId}/${docId}] from [${oldOwner.userId}] to [${userId}]` - ); - } else { - this.logger.log( - `Set doc owner of [${workspaceId}/${docId}] to [${userId}]` - ); - } + this.logger.log( + `Set doc owner of [${workspaceId}/${docId}] to [${userId}]` + ); } /** diff --git a/packages/backend/server/src/native.ts b/packages/backend/server/src/native.ts index 7563eb7a5a..1eb481df31 100644 --- a/packages/backend/server/src/native.ts +++ b/packages/backend/server/src/native.ts @@ -40,6 +40,7 @@ export function getTokenEncoder(model?: string | null): Tokenizer | null { export const getMime = serverNativeModule.getMime; export const parseDoc = serverNativeModule.parseDoc; export const htmlSanitize = serverNativeModule.htmlSanitize; +export const processImage = serverNativeModule.processImage; export const parseYDocFromBinary = serverNativeModule.parseDocFromBinary; export const parseYDocToMarkdown = serverNativeModule.parseDocToMarkdown; export const parsePageDocFromBinary = serverNativeModule.parsePageDoc; @@ -57,3 +58,461 @@ export const addDocToRootDoc = serverNativeModule.addDocToRootDoc; export const updateDocTitle = serverNativeModule.updateDocTitle; 
export const updateDocProperties = serverNativeModule.updateDocProperties; export const updateRootDocMetaTitle = serverNativeModule.updateRootDocMetaTitle; + +type NativeLlmModule = { + llmDispatch?: ( + protocol: string, + backendConfigJson: string, + requestJson: string + ) => string | Promise; + llmStructuredDispatch?: ( + protocol: string, + backendConfigJson: string, + requestJson: string + ) => string | Promise; + llmEmbeddingDispatch?: ( + protocol: string, + backendConfigJson: string, + requestJson: string + ) => string | Promise; + llmRerankDispatch?: ( + protocol: string, + backendConfigJson: string, + requestJson: string + ) => string | Promise; + llmDispatchStream?: ( + protocol: string, + backendConfigJson: string, + requestJson: string, + callback: (error: Error | null, eventJson: string) => void + ) => { abort?: () => void } | undefined; +}; + +const nativeLlmModule = serverNativeModule as typeof serverNativeModule & + NativeLlmModule; + +export type NativeLlmProtocol = + | 'openai_chat' + | 'openai_responses' + | 'anthropic' + | 'gemini'; + +export type NativeLlmBackendConfig = { + base_url: string; + auth_token: string; + request_layer?: + | 'anthropic' + | 'chat_completions' + | 'responses' + | 'vertex' + | 'vertex_anthropic' + | 'gemini_api' + | 'gemini_vertex'; + headers?: Record; + no_streaming?: boolean; + timeout_ms?: number; +}; + +export type NativeLlmCoreRole = 'system' | 'user' | 'assistant' | 'tool'; + +export type NativeLlmCoreContent = + | { type: 'text'; text: string } + | { type: 'reasoning'; text: string; signature?: string } + | { + type: 'tool_call'; + call_id: string; + name: string; + arguments: Record; + arguments_text?: string; + arguments_error?: string; + thought?: string; + } + | { + type: 'tool_result'; + call_id: string; + output: unknown; + is_error?: boolean; + name?: string; + arguments?: Record; + arguments_text?: string; + arguments_error?: string; + } + | { type: 'image'; source: Record | string } + | { type: 
'audio'; source: Record | string } + | { type: 'file'; source: Record | string }; + +export type NativeLlmCoreMessage = { + role: NativeLlmCoreRole; + content: NativeLlmCoreContent[]; +}; + +export type NativeLlmToolDefinition = { + name: string; + description?: string; + parameters: Record; +}; + +export type NativeLlmRequest = { + model: string; + messages: NativeLlmCoreMessage[]; + stream?: boolean; + max_tokens?: number; + temperature?: number; + tools?: NativeLlmToolDefinition[]; + tool_choice?: 'auto' | 'none' | 'required' | { name: string }; + include?: string[]; + reasoning?: Record; + response_schema?: Record; + middleware?: { + request?: Array< + 'normalize_messages' | 'clamp_max_tokens' | 'tool_schema_rewrite' + >; + stream?: Array<'stream_event_normalize' | 'citation_indexing'>; + config?: { + additional_properties_policy?: 'preserve' | 'forbid'; + property_format_policy?: 'preserve' | 'drop'; + property_min_length_policy?: 'preserve' | 'drop'; + array_min_items_policy?: 'preserve' | 'drop'; + array_max_items_policy?: 'preserve' | 'drop'; + max_tokens_cap?: number; + }; + }; +}; + +export type NativeLlmStructuredRequest = { + model: string; + messages: NativeLlmCoreMessage[]; + schema: Record; + max_tokens?: number; + temperature?: number; + reasoning?: Record; + strict?: boolean; + response_mime_type?: string; + middleware?: NativeLlmRequest['middleware']; +}; + +export type NativeLlmEmbeddingRequest = { + model: string; + inputs: string[]; + dimensions?: number; + task_type?: string; +}; + +export type NativeLlmRerankCandidate = { + id?: string; + text: string; +}; + +export type NativeLlmRerankRequest = { + model: string; + query: string; + candidates: NativeLlmRerankCandidate[]; + top_n?: number; +}; + +export type NativeLlmDispatchResponse = { + id: string; + model: string; + message: NativeLlmCoreMessage; + usage: { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; + cached_tokens?: number; + }; + finish_reason: + | 
'stop' + | 'length' + | 'tool_calls' + | 'content_filter' + | 'error' + | string; + reasoning_details?: unknown; +}; + +export type NativeLlmStructuredResponse = { + id: string; + model: string; + output_text: string; + usage: NativeLlmDispatchResponse['usage']; + finish_reason: NativeLlmDispatchResponse['finish_reason']; + reasoning_details?: unknown; +}; + +export type NativeLlmEmbeddingResponse = { + model: string; + embeddings: number[][]; + usage?: { + prompt_tokens: number; + total_tokens: number; + }; +}; + +export type NativeLlmRerankResponse = { + model: string; + scores: number[]; +}; + +export type NativeLlmStreamEvent = + | { type: 'message_start'; id?: string; model?: string } + | { type: 'text_delta'; text: string } + | { type: 'reasoning_delta'; text: string } + | { + type: 'tool_call_delta'; + call_id: string; + name?: string; + arguments_delta: string; + } + | { + type: 'tool_call'; + call_id: string; + name: string; + arguments: Record; + arguments_text?: string; + arguments_error?: string; + thought?: string; + } + | { + type: 'tool_result'; + call_id: string; + output: unknown; + is_error?: boolean; + name?: string; + arguments?: Record; + arguments_text?: string; + arguments_error?: string; + } + | { type: 'citation'; index: number; url: string } + | { + type: 'usage'; + usage: { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; + cached_tokens?: number; + }; + } + | { + type: 'done'; + finish_reason?: NativeLlmDispatchResponse['finish_reason']; + usage?: { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; + cached_tokens?: number; + }; + } + | { type: 'error'; message: string; code?: string; raw?: string }; +const LLM_STREAM_END_MARKER = '__AFFINE_LLM_STREAM_END__'; + +export async function llmDispatch( + protocol: NativeLlmProtocol, + backendConfig: NativeLlmBackendConfig, + request: NativeLlmRequest +): Promise { + if (!nativeLlmModule.llmDispatch) { + throw new Error('native llm 
dispatch is not available'); + } + const response = nativeLlmModule.llmDispatch( + protocol, + JSON.stringify(backendConfig), + JSON.stringify(request) + ); + const responseText = await Promise.resolve(response); + return JSON.parse(responseText) as NativeLlmDispatchResponse; +} + +export async function llmStructuredDispatch( + protocol: NativeLlmProtocol, + backendConfig: NativeLlmBackendConfig, + request: NativeLlmStructuredRequest +): Promise { + if (!nativeLlmModule.llmStructuredDispatch) { + throw new Error('native llm structured dispatch is not available'); + } + const response = nativeLlmModule.llmStructuredDispatch( + protocol, + JSON.stringify(backendConfig), + JSON.stringify(request) + ); + const responseText = await Promise.resolve(response); + return JSON.parse(responseText) as NativeLlmStructuredResponse; +} + +export async function llmEmbeddingDispatch( + protocol: NativeLlmProtocol, + backendConfig: NativeLlmBackendConfig, + request: NativeLlmEmbeddingRequest +): Promise { + if (!nativeLlmModule.llmEmbeddingDispatch) { + throw new Error('native llm embedding dispatch is not available'); + } + const response = nativeLlmModule.llmEmbeddingDispatch( + protocol, + JSON.stringify(backendConfig), + JSON.stringify(request) + ); + const responseText = await Promise.resolve(response); + return JSON.parse(responseText) as NativeLlmEmbeddingResponse; +} + +export async function llmRerankDispatch( + protocol: NativeLlmProtocol, + backendConfig: NativeLlmBackendConfig, + request: NativeLlmRerankRequest +): Promise { + if (!nativeLlmModule.llmRerankDispatch) { + throw new Error('native llm rerank dispatch is not available'); + } + const response = nativeLlmModule.llmRerankDispatch( + protocol, + JSON.stringify(backendConfig), + JSON.stringify(request) + ); + const responseText = await Promise.resolve(response); + return JSON.parse(responseText) as NativeLlmRerankResponse; +} + +export class NativeStreamAdapter implements AsyncIterableIterator { + readonly #queue: 
T[] = []; + readonly #waiters: ((result: IteratorResult) => void)[] = []; + readonly #handle: { abort?: () => void } | undefined; + readonly #signal?: AbortSignal; + readonly #abortListener?: () => void; + #ended = false; + + constructor( + handle: { abort?: () => void } | undefined, + signal?: AbortSignal + ) { + this.#handle = handle; + this.#signal = signal; + + if (signal?.aborted) { + this.close(true); + return; + } + + if (signal) { + this.#abortListener = () => { + this.close(true); + }; + signal.addEventListener('abort', this.#abortListener, { once: true }); + } + } + + private close(abortHandle: boolean) { + if (this.#ended) { + return; + } + + this.#ended = true; + if (this.#signal && this.#abortListener) { + this.#signal.removeEventListener('abort', this.#abortListener); + } + if (abortHandle) { + this.#handle?.abort?.(); + } + + while (this.#waiters.length) { + const waiter = this.#waiters.shift(); + waiter?.({ value: undefined as T, done: true }); + } + } + + push(value: T | null) { + if (this.#ended) { + return; + } + + if (value === null) { + this.close(false); + return; + } + + const waiter = this.#waiters.shift(); + if (waiter) { + waiter({ value, done: false }); + return; + } + + this.#queue.push(value); + } + + [Symbol.asyncIterator]() { + return this; + } + + async next(): Promise> { + if (this.#queue.length > 0) { + const value = this.#queue.shift() as T; + return { value, done: false }; + } + + if (this.#ended) { + return { value: undefined as T, done: true }; + } + + return await new Promise(resolve => { + this.#waiters.push(resolve); + }); + } + + async return(): Promise> { + this.close(true); + + return { value: undefined as T, done: true }; + } +} + +export function llmDispatchStream( + protocol: NativeLlmProtocol, + backendConfig: NativeLlmBackendConfig, + request: NativeLlmRequest, + signal?: AbortSignal +): AsyncIterableIterator { + if (!nativeLlmModule.llmDispatchStream) { + throw new Error('native llm stream dispatch is not 
available'); + } + + let adapter: NativeStreamAdapter | undefined; + const buffer: (NativeLlmStreamEvent | null)[] = []; + let pushFn = (event: NativeLlmStreamEvent | null) => { + buffer.push(event); + }; + const handle = nativeLlmModule.llmDispatchStream( + protocol, + JSON.stringify(backendConfig), + JSON.stringify(request), + (error, eventJson) => { + if (error) { + pushFn({ type: 'error', message: error.message, raw: eventJson }); + return; + } + if (eventJson === LLM_STREAM_END_MARKER) { + pushFn(null); + return; + } + try { + pushFn(JSON.parse(eventJson) as NativeLlmStreamEvent); + } catch (error) { + pushFn({ + type: 'error', + message: + error instanceof Error + ? error.message + : 'failed to parse native stream event', + raw: eventJson, + }); + } + } + ); + adapter = new NativeStreamAdapter(handle, signal); + pushFn = event => { + adapter.push(event); + }; + for (const event of buffer) { + adapter.push(event); + } + return adapter; +} diff --git a/packages/backend/server/src/plugins/calendar/providers/def.ts b/packages/backend/server/src/plugins/calendar/providers/def.ts index a9089d2ab6..02ad30482a 100644 --- a/packages/backend/server/src/plugins/calendar/providers/def.ts +++ b/packages/backend/server/src/plugins/calendar/providers/def.ts @@ -154,8 +154,8 @@ export abstract class CalendarProvider { protected async fetchJson(url: string, init?: RequestInit) { const response = await fetch(url, { - headers: { Accept: 'application/json', ...init?.headers }, ...init, + headers: { ...init?.headers, Accept: 'application/json' }, }); const body = await response.text(); if (!response.ok) { diff --git a/packages/backend/server/src/plugins/copilot/config.ts b/packages/backend/server/src/plugins/copilot/config.ts index 0f40790f67..aa9e0feffe 100644 --- a/packages/backend/server/src/plugins/copilot/config.ts +++ b/packages/backend/server/src/plugins/copilot/config.ts @@ -1,3 +1,5 @@ +import { z } from 'zod'; + import { defineModuleConfig, StorageJSONSchema, @@ -13,7 
+15,180 @@ import { GeminiGenerativeConfig, GeminiVertexConfig } from './providers/gemini'; import { MorphConfig } from './providers/morph'; import { OpenAIConfig } from './providers/openai'; import { PerplexityConfig } from './providers/perplexity'; -import { VertexSchema } from './providers/types'; +import { + CopilotProviderType, + ModelOutputType, + VertexSchema, +} from './providers/types'; + +export type CopilotProviderConfigMap = { + [CopilotProviderType.OpenAI]: OpenAIConfig; + [CopilotProviderType.FAL]: FalConfig; + [CopilotProviderType.Gemini]: GeminiGenerativeConfig; + [CopilotProviderType.GeminiVertex]: GeminiVertexConfig; + [CopilotProviderType.Perplexity]: PerplexityConfig; + [CopilotProviderType.Anthropic]: AnthropicOfficialConfig; + [CopilotProviderType.AnthropicVertex]: AnthropicVertexConfig; + [CopilotProviderType.Morph]: MorphConfig; +}; + +export type ProviderSpecificConfig = + CopilotProviderConfigMap[keyof CopilotProviderConfigMap]; + +export const RustRequestMiddlewareValues = [ + 'normalize_messages', + 'clamp_max_tokens', + 'tool_schema_rewrite', +] as const; +export type RustRequestMiddleware = + (typeof RustRequestMiddlewareValues)[number]; + +export const RustStreamMiddlewareValues = [ + 'stream_event_normalize', + 'citation_indexing', +] as const; +export type RustStreamMiddleware = (typeof RustStreamMiddlewareValues)[number]; + +export const NodeTextMiddlewareValues = [ + 'citation_footnote', + 'callout', + 'thinking_format', +] as const; +export type NodeTextMiddleware = (typeof NodeTextMiddlewareValues)[number]; + +export type ProviderMiddlewareConfig = { + rust?: { request?: RustRequestMiddleware[]; stream?: RustStreamMiddleware[] }; + node?: { text?: NodeTextMiddleware[] }; +}; + +type CopilotProviderProfileCommon = { + id: string; + displayName?: string; + priority?: number; + enabled?: boolean; + models?: string[]; + middleware?: ProviderMiddlewareConfig; +}; + +type CopilotProviderProfileVariant = { + type: T; + config: 
CopilotProviderConfigMap[T]; +}; + +export type CopilotProviderProfile = CopilotProviderProfileCommon & + { + [Type in CopilotProviderType]: CopilotProviderProfileVariant; + }[CopilotProviderType]; + +export type CopilotProviderDefaults = Partial< + Record, string> +> & { + fallback?: string; +}; + +const CopilotProviderProfileBaseShape = z.object({ + id: z.string().regex(/^[a-zA-Z0-9-_]+$/), + displayName: z.string().optional(), + priority: z.number().optional(), + enabled: z.boolean().optional(), + models: z.array(z.string()).optional(), + middleware: z + .object({ + rust: z + .object({ + request: z.array(z.enum(RustRequestMiddlewareValues)).optional(), + stream: z.array(z.enum(RustStreamMiddlewareValues)).optional(), + }) + .optional(), + node: z + .object({ text: z.array(z.enum(NodeTextMiddlewareValues)).optional() }) + .optional(), + }) + .optional(), +}); + +const OpenAIConfigShape = z.object({ + apiKey: z.string(), + baseURL: z.string().optional(), + oldApiStyle: z.boolean().optional(), +}); + +const FalConfigShape = z.object({ + apiKey: z.string(), +}); + +const GeminiGenerativeConfigShape = z.object({ + apiKey: z.string(), + baseURL: z.string().optional(), +}); + +const VertexProviderConfigShape = z.object({ + location: z.string().optional(), + project: z.string().optional(), + baseURL: z.string().optional(), + googleAuthOptions: z.any().optional(), + fetch: z.any().optional(), +}); + +const PerplexityConfigShape = z.object({ + apiKey: z.string(), + endpoint: z.string().optional(), +}); + +const AnthropicOfficialConfigShape = z.object({ + apiKey: z.string(), + baseURL: z.string().optional(), +}); + +const MorphConfigShape = z.object({ + apiKey: z.string().optional(), +}); + +const CopilotProviderProfileShape = z.discriminatedUnion('type', [ + CopilotProviderProfileBaseShape.extend({ + type: z.literal(CopilotProviderType.OpenAI), + config: OpenAIConfigShape, + }), + CopilotProviderProfileBaseShape.extend({ + type: z.literal(CopilotProviderType.FAL), + 
config: FalConfigShape, + }), + CopilotProviderProfileBaseShape.extend({ + type: z.literal(CopilotProviderType.Gemini), + config: GeminiGenerativeConfigShape, + }), + CopilotProviderProfileBaseShape.extend({ + type: z.literal(CopilotProviderType.GeminiVertex), + config: VertexProviderConfigShape, + }), + CopilotProviderProfileBaseShape.extend({ + type: z.literal(CopilotProviderType.Perplexity), + config: PerplexityConfigShape, + }), + CopilotProviderProfileBaseShape.extend({ + type: z.literal(CopilotProviderType.Anthropic), + config: AnthropicOfficialConfigShape, + }), + CopilotProviderProfileBaseShape.extend({ + type: z.literal(CopilotProviderType.AnthropicVertex), + config: VertexProviderConfigShape, + }), + CopilotProviderProfileBaseShape.extend({ + type: z.literal(CopilotProviderType.Morph), + config: MorphConfigShape, + }), +]); + +const CopilotProviderDefaultsShape = z.object({ + [ModelOutputType.Text]: z.string().optional(), + [ModelOutputType.Object]: z.string().optional(), + [ModelOutputType.Embedding]: z.string().optional(), + [ModelOutputType.Image]: z.string().optional(), + [ModelOutputType.Rerank]: z.string().optional(), + [ModelOutputType.Structured]: z.string().optional(), + fallback: z.string().optional(), +}); + declare global { interface AppConfigSchema { copilot: { @@ -27,6 +202,8 @@ declare global { storage: ConfigItem; scenarios: ConfigItem; providers: { + profiles: ConfigItem; + defaults: ConfigItem; openai: ConfigItem; fal: ConfigItem; gemini: ConfigItem; @@ -54,15 +231,24 @@ defineModuleConfig('copilot', { chat: 'gemini-2.5-flash', embedding: 'gemini-embedding-001', image: 'gpt-image-1', - rerank: 'gpt-4.1', coding: 'claude-sonnet-4-5@20250929', - complex_text_generation: 'gpt-4o-2024-08-06', + complex_text_generation: 'gpt-5-mini', quick_decision_making: 'gpt-5-mini', quick_text_generation: 'gemini-2.5-flash', polish_and_summarize: 'gemini-2.5-flash', }, }, }, + 'providers.profiles': { + desc: 'The profile list for copilot providers.', + 
default: [], + shape: z.array(CopilotProviderProfileShape), + }, + 'providers.defaults': { + desc: 'The default provider ids for model output types and global fallback.', + default: {}, + shape: CopilotProviderDefaultsShape, + }, 'providers.openai': { desc: 'The config for the openai provider.', default: { diff --git a/packages/backend/server/src/plugins/copilot/controller.ts b/packages/backend/server/src/plugins/copilot/controller.ts index 921e50a088..d76ea50591 100644 --- a/packages/backend/server/src/plugins/copilot/controller.ts +++ b/packages/backend/server/src/plugins/copilot/controller.ts @@ -36,10 +36,7 @@ import { BlobNotFound, CallMetric, Config, - CopilotFailedToGenerateText, CopilotSessionNotFound, - InternalServerError, - mapAnyError, mapSseError, metrics, NoCopilotProviderAvailable, @@ -242,61 +239,6 @@ export class CopilotController implements BeforeApplicationShutdown { }; } - @Get('/chat/:sessionId') - @CallMetric('ai', 'chat', { timer: true }) - async chat( - @CurrentUser() user: CurrentUser, - @Req() req: Request, - @Param('sessionId') sessionId: string, - @Query() query: Record - ): Promise { - const info: any = { sessionId, params: query }; - - try { - const { provider, model, session, finalMessage } = - await this.prepareChatSession( - user, - sessionId, - query, - ModelOutputType.Text - ); - - info.model = model; - info.finalMessage = finalMessage.filter(m => m.role !== 'system'); - metrics.ai.counter('chat_calls').add(1, { model }); - - const { reasoning, webSearch, toolsConfig } = - ChatQuerySchema.parse(query); - const content = await provider.text({ modelId: model }, finalMessage, { - ...session.config.promptConfig, - signal: getSignal(req).signal, - user: user.id, - session: session.config.sessionId, - workspace: session.config.workspaceId, - reasoning, - webSearch, - tools: getTools(session.config.promptConfig?.tools, toolsConfig), - }); - - session.push({ - role: 'assistant', - content, - createdAt: new Date(), - }); - await 
session.save(); - - return content; - } catch (e: any) { - metrics.ai.counter('chat_errors').add(1); - let error = mapAnyError(e); - if (error instanceof InternalServerError) { - error = new CopilotFailedToGenerateText(e.message); - } - error.log('CopilotChat', info); - throw error; - } - } - @Sse('/chat/:sessionId/stream') @CallMetric('ai', 'chat_stream', { timer: true }) async chatStream( diff --git a/packages/backend/server/src/plugins/copilot/embedding/client.ts b/packages/backend/server/src/plugins/copilot/embedding/client.ts index 8268cae01e..34e5ccfa5f 100644 --- a/packages/backend/server/src/plugins/copilot/embedding/client.ts +++ b/packages/backend/server/src/plugins/copilot/embedding/client.ts @@ -1,21 +1,17 @@ import { Logger } from '@nestjs/common'; import type { ModuleRef } from '@nestjs/core'; -import { - Config, - CopilotPromptNotFound, - CopilotProviderNotSupported, -} from '../../../base'; +import { Config, CopilotProviderNotSupported } from '../../../base'; import { CopilotFailedToGenerateEmbedding } from '../../../base/error/errors.gen'; import { ChunkSimilarity, Embedding, EMBEDDING_DIMENSIONS, } from '../../../models'; -import { PromptService } from '../prompt/service'; import { CopilotProviderFactory } from '../providers/factory'; import type { CopilotProvider } from '../providers/provider'; import { + type CopilotRerankRequest, type ModelFullConditions, ModelInputType, ModelOutputType, @@ -23,24 +19,20 @@ import { import { EmbeddingClient, type ReRankResult } from './types'; const EMBEDDING_MODEL = 'gemini-embedding-001'; -const RERANK_PROMPT = 'Rerank results'; - +const RERANK_MODEL = 'gpt-5.2'; class ProductionEmbeddingClient extends EmbeddingClient { private readonly logger = new Logger(ProductionEmbeddingClient.name); constructor( private readonly config: Config, - private readonly providerFactory: CopilotProviderFactory, - private readonly prompt: PromptService + private readonly providerFactory: CopilotProviderFactory ) { super(); } 
override async configured(): Promise { const embedding = await this.providerFactory.getProvider({ - modelId: this.config.copilot?.scenarios?.override_enabled - ? this.config.copilot.scenarios.scenarios?.embedding || EMBEDDING_MODEL - : EMBEDDING_MODEL, + modelId: this.getEmbeddingModelId(), outputType: ModelOutputType.Embedding, }); const result = Boolean(embedding); @@ -65,9 +57,15 @@ class ProductionEmbeddingClient extends EmbeddingClient { return provider; } + private getEmbeddingModelId() { + return this.config.copilot?.scenarios?.override_enabled + ? this.config.copilot.scenarios.scenarios?.embedding || EMBEDDING_MODEL + : EMBEDDING_MODEL; + } + async getEmbeddings(input: string[]): Promise { const provider = await this.getProvider({ - modelId: EMBEDDING_MODEL, + modelId: this.getEmbeddingModelId(), outputType: ModelOutputType.Embedding, }); this.logger.verbose( @@ -110,15 +108,22 @@ class ProductionEmbeddingClient extends EmbeddingClient { ): Promise { if (!embeddings.length) return []; - const prompt = await this.prompt.get(RERANK_PROMPT); - if (!prompt) { - throw new CopilotPromptNotFound({ name: RERANK_PROMPT }); - } - const provider = await this.getProvider({ modelId: prompt.model }); + const provider = await this.getProvider({ + modelId: RERANK_MODEL, + outputType: ModelOutputType.Rerank, + }); + + const rerankRequest: CopilotRerankRequest = { + query, + candidates: embeddings.map((embedding, index) => ({ + id: String(index), + text: embedding.content, + })), + }; const ranks = await provider.rerank( - { modelId: prompt.model }, - embeddings.map(e => prompt.finish({ query, doc: e.content })), + { modelId: RERANK_MODEL }, + rerankRequest, { signal } ); @@ -171,7 +176,7 @@ class ProductionEmbeddingClient extends EmbeddingClient { ); try { - // 4.1 mini's context windows large enough to handle all embeddings + // The rerank prompt is expected to handle the full deduped candidate list. 
const ranks = await this.getEmbeddingRelevance( query, sortedEmbeddings, @@ -217,9 +222,7 @@ export async function getEmbeddingClient( const providerFactory = moduleRef.get(CopilotProviderFactory, { strict: false, }); - const prompt = moduleRef.get(PromptService, { strict: false }); - - const client = new ProductionEmbeddingClient(config, providerFactory, prompt); + const client = new ProductionEmbeddingClient(config, providerFactory); if (await client.configured()) { EMBEDDING_CLIENT = client; } diff --git a/packages/backend/server/src/plugins/copilot/mcp/controller.ts b/packages/backend/server/src/plugins/copilot/mcp/controller.ts index e738d9d31f..0688bbc4bd 100644 --- a/packages/backend/server/src/plugins/copilot/mcp/controller.ts +++ b/packages/backend/server/src/plugins/copilot/mcp/controller.ts @@ -1,4 +1,3 @@ -import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'; import { Controller, Delete, @@ -13,28 +12,51 @@ import { } from '@nestjs/common'; import type { Request, Response } from 'express'; +import { Throttle } from '../../../base'; import { CurrentUser } from '../../../core/auth'; -import { WorkspaceMcpProvider } from './provider'; +import { WorkspaceMcpProvider, type WorkspaceMcpServer } from './provider'; + +type JsonRpcId = string | number | null; + +type JsonRpcErrorResponse = { + jsonrpc: '2.0'; + error: { code: number; message: string }; + id: JsonRpcId; +}; + +type JsonRpcSuccessResponse = { + jsonrpc: '2.0'; + result: Record; + id: JsonRpcId; +}; + +type JsonRpcResponse = JsonRpcErrorResponse | JsonRpcSuccessResponse; + +const JSON_RPC_VERSION = '2.0'; +const MAX_BATCH_SIZE = 20; +const DEFAULT_PROTOCOL_VERSION = '2025-03-26'; +const SUPPORTED_PROTOCOL_VERSIONS = new Set([ + '2025-11-25', + '2025-06-18', + '2025-03-26', + '2024-11-05', + '2024-10-07', +]); @Controller('/api/workspaces/:workspaceId/mcp') export class WorkspaceMcpController { private readonly logger = new 
Logger(WorkspaceMcpController.name); + constructor(private readonly provider: WorkspaceMcpProvider) {} @Get('/') @Delete('/') @HttpCode(HttpStatus.METHOD_NOT_ALLOWED) async STATELESS_MCP_ENDPOINT() { - return { - jsonrpc: '2.0', - error: { - code: -32000, - message: 'Method not allowed.', - }, - id: null, - }; + return this.errorResponse(null, -32000, 'Method not allowed.'); } + @Throttle('default') @Post('/') async mcp( @Req() req: Request, @@ -42,28 +64,202 @@ export class WorkspaceMcpController { @CurrentUser() user: CurrentUser, @Param('workspaceId') workspaceId: string ) { - let server = await this.provider.for(user.id, workspaceId); - - const transport: StreamableHTTPServerTransport = - new StreamableHTTPServerTransport({ - sessionIdGenerator: undefined, - }); - - const cleanup = () => { - transport.close().catch(e => { - this.logger.error('Failed to close MCP transport', e); - }); - server.close().catch(e => { - this.logger.error('Failed to close MCP server', e); - }); - }; + const abortController = new AbortController(); + req.on('close', () => abortController.abort()); try { - res.on('close', cleanup); - await server.connect(transport); - await transport.handleRequest(req, res, req.body); - } catch { - cleanup(); + const server = await this.provider.for(user.id, workspaceId); + const body = req.body as unknown; + const isBatch = Array.isArray(body); + const messages = isBatch ? 
body : [body]; + + if (!messages.length) { + res + .status(HttpStatus.BAD_REQUEST) + .json(this.errorResponse(null, -32600, 'Invalid Request')); + return; + } + if (messages.length > MAX_BATCH_SIZE) { + res + .status(HttpStatus.BAD_REQUEST) + .json( + this.errorResponse( + null, + -32600, + `Batch size exceeds limit (${MAX_BATCH_SIZE}).` + ) + ); + return; + } + + const responses: JsonRpcResponse[] = []; + for (const message of messages) { + const response = await this.handleMessage( + message, + server, + abortController.signal + ); + if (response) { + responses.push(response); + } + } + + if (!responses.length) { + res.status(HttpStatus.ACCEPTED).send(); + return; + } + + res.status(HttpStatus.OK).json(isBatch ? responses : responses[0]); + } catch (error) { + this.logger.error('Failed to handle MCP request', error); + res + .status(HttpStatus.INTERNAL_SERVER_ERROR) + .json(this.errorResponse(null, -32603, 'Internal error')); } } + + private async handleMessage( + message: unknown, + server: WorkspaceMcpServer, + signal: AbortSignal + ): Promise { + const rawRequest = this.asObject(message); + if (!rawRequest || rawRequest.jsonrpc !== JSON_RPC_VERSION) { + return this.errorResponse(null, -32600, 'Invalid Request'); + } + + const method = rawRequest.method; + if (typeof method !== 'string') { + return this.errorResponse(null, -32600, 'Invalid Request'); + } + + const id = this.parseRequestId(rawRequest.id); + if (id === 'invalid') { + return this.errorResponse(null, -32600, 'Invalid Request'); + } + + const isNotification = id === undefined; + const responseId = id ?? null; + + switch (method) { + case 'initialize': { + const params = this.asObject(rawRequest.params); + const requestedVersion = + params && typeof params.protocolVersion === 'string' + ? params.protocolVersion + : DEFAULT_PROTOCOL_VERSION; + const protocolVersion = SUPPORTED_PROTOCOL_VERSIONS.has( + requestedVersion + ) + ? 
requestedVersion + : DEFAULT_PROTOCOL_VERSION; + + if (isNotification) return null; + + return this.successResponse(responseId, { + protocolVersion, + capabilities: { tools: {} }, + serverInfo: { name: server.name, version: server.version }, + }); + } + + case 'notifications/initialized': + case 'ping': { + if (isNotification) { + return null; + } + return this.successResponse(responseId, {}); + } + + case 'tools/list': { + if (isNotification) { + return null; + } + return this.successResponse(responseId, { + tools: server.tools.map(tool => ({ + name: tool.name, + title: tool.title, + description: tool.description, + inputSchema: tool.inputSchema, + })), + }); + } + + case 'tools/call': { + const params = this.asObject(rawRequest.params); + if (!params || typeof params.name !== 'string') { + return this.errorResponse(responseId, -32602, 'Invalid params'); + } + + const tool = server.tools.find(item => item.name === params.name); + if (!tool) { + return this.errorResponse( + responseId, + -32602, + `Tool not found: ${params.name}` + ); + } + + const args = this.asObject(params.arguments) ?? {}; + try { + const result = await tool.execute(args, { signal }); + if (isNotification) return null; + + return this.successResponse( + responseId, + result as Record + ); + } catch (error) { + this.logger.error( + `Error executing tool in mcp ${tool.name}`, + error instanceof Error ? error.stack : String(error) + ); + return this.errorResponse( + responseId, + -32001, + `Error executing tool: ${error instanceof Error ? 
error.message : String(error)}` + ); + } + } + + default: { + if (isNotification) return null; + return this.errorResponse(responseId, -32601, 'Method not found'); + } + } + } + + private asObject(value: unknown): Record | null { + if (!value || typeof value !== 'object' || Array.isArray(value)) { + return null; + } + return value as Record; + } + + private parseRequestId(value: unknown): JsonRpcId | undefined | 'invalid' { + if (value === undefined) return undefined; + if ( + value === null || + typeof value === 'string' || + typeof value === 'number' + ) { + return value; + } + return 'invalid'; + } + + private successResponse( + id: JsonRpcId, + result: Record + ): JsonRpcSuccessResponse { + return { jsonrpc: JSON_RPC_VERSION, result, id }; + } + + private errorResponse( + id: JsonRpcId, + code: number, + message: string + ): JsonRpcErrorResponse { + return { jsonrpc: JSON_RPC_VERSION, error: { code, message }, id }; + } } diff --git a/packages/backend/server/src/plugins/copilot/mcp/provider.ts b/packages/backend/server/src/plugins/copilot/mcp/provider.ts index bd15acc6cc..73f14d4f86 100644 --- a/packages/backend/server/src/plugins/copilot/mcp/provider.ts +++ b/packages/backend/server/src/plugins/copilot/mcp/provider.ts @@ -1,5 +1,3 @@ -import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; -import type { CallToolResult } from '@modelcontextprotocol/sdk/types.js'; import { Injectable } from '@nestjs/common'; import { pick } from 'lodash-es'; import z from 'zod/v3'; @@ -10,6 +8,94 @@ import { clearEmbeddingChunk } from '../../../models'; import { IndexerService } from '../../indexer'; import { CopilotContextService } from '../context/service'; +type McpTextContent = { + type: 'text'; + text: string; +}; + +export type WorkspaceMcpToolResult = { + content: McpTextContent[]; + isError?: boolean; +}; + +export type WorkspaceMcpToolDefinition = { + name: string; + title: string; + description: string; + inputSchema: Record; + execute: ( + args: Record, 
+ options: { signal: AbortSignal } + ) => Promise; +}; + +export type WorkspaceMcpServer = { + name: string; + version: string; + tools: WorkspaceMcpToolDefinition[]; +}; + +type ToolExecutorInput = { + name: string; + title: string; + description: string; + parser: T; + inputSchema: Record; + execute: ( + args: z.infer, + options: { signal: AbortSignal } + ) => Promise; +}; + +function toolText(text: string): WorkspaceMcpToolResult { + return { + content: [{ type: 'text', text }], + }; +} + +function toolError(message: string): WorkspaceMcpToolResult { + return { + isError: true, + content: [{ type: 'text', text: message }], + }; +} + +function toInputError(error: z.ZodError) { + const details = error.issues + .map(issue => { + const path = issue.path.join('.'); + return path ? `${path}: ${issue.message}` : issue.message; + }) + .join('; '); + return toolError(`Invalid arguments: ${details || 'Invalid input'}`); +} + +function abortIfNeeded( + signal: AbortSignal +): WorkspaceMcpToolResult | undefined { + if (signal.aborted) return toolError('Request aborted.'); + return; +} + +function defineTool( + config: ToolExecutorInput +): WorkspaceMcpToolDefinition { + return { + name: config.name, + title: config.title, + description: config.description, + inputSchema: config.inputSchema, + execute: async (args, options) => { + const aborted = abortIfNeeded(options.signal); + if (aborted) return aborted; + + const parsed = config.parser.safeParse(args ?? 
{}); + if (!parsed.success) return toInputError(parsed.error); + return await config.execute(parsed.data, options); + }, + }; +} + @Injectable() export class WorkspaceMcpProvider { constructor( @@ -20,190 +106,182 @@ export class WorkspaceMcpProvider { private readonly indexer: IndexerService ) {} - async for(userId: string, workspaceId: string) { + async for(userId: string, workspaceId: string): Promise { await this.ac.user(userId).workspace(workspaceId).assert('Workspace.Read'); - const server = new McpServer({ - name: `AFFiNE MCP Server for Workspace ${workspaceId}`, - version: '1.0.0', - }); - - server.registerTool( - 'read_document', - { - title: 'Read Document', - description: 'Read a document with given ID', - inputSchema: z.object({ - docId: z.string(), - }), + const readDocument = defineTool({ + name: 'read_document', + title: 'Read Document', + description: 'Read a document with given ID', + parser: z.object({ docId: z.string() }), + inputSchema: { + type: 'object', + properties: { + docId: { type: 'string' }, + }, + required: ['docId'], + additionalProperties: false, }, - async ({ docId }) => { - const notFoundError: CallToolResult = { - isError: true, - content: [ - { - type: 'text', - text: `Doc with id ${docId} not found.`, - }, - ], - }; + execute: async ({ docId }, options) => { + const notFoundError = toolError(`Doc with id ${docId} not found.`); const accessible = await this.ac .user(userId) .workspace(workspaceId) .doc(docId) .can('Doc.Read'); + if (!accessible) return notFoundError; - if (!accessible) { - return notFoundError; - } + const abortedAfterPermission = abortIfNeeded(options.signal); + if (abortedAfterPermission) return abortedAfterPermission; const content = await this.reader.getDocMarkdown( workspaceId, docId, false ); + if (!content) return notFoundError; - if (!content) { - return notFoundError; - } + const abortedAfterRead = abortIfNeeded(options.signal); + if (abortedAfterRead) return abortedAfterRead; - return { - content: [ - { 
- type: 'text', - text: content.markdown, - }, - ], - } as const; - } - ); - - server.registerTool( - 'semantic_search', - { - title: 'Semantic Search', - description: - 'Retrieve conceptually related passages by performing vector-based semantic similarity search across embedded documents; use this tool only when exact keyword search fails or the user explicitly needs meaning-level matches (e.g., paraphrases, synonyms, broader concepts, recent documents).', - inputSchema: z.object({ - query: z.string(), - }), + return toolText(content.markdown); }, - async ({ query }, req) => { - query = query.trim(); - if (!query) { - return { - isError: true, - content: [ - { - type: 'text', - text: 'Query is required for semantic search.', - }, - ], - }; + }); + + const semanticSearch = defineTool({ + name: 'semantic_search', + title: 'Semantic Search', + description: + 'Retrieve conceptually related passages by performing vector-based semantic similarity search across embedded documents; use this tool only when exact keyword search fails or the user explicitly needs meaning-level matches (e.g., paraphrases, synonyms, broader concepts, recent documents).', + parser: z.object({ query: z.string() }), + inputSchema: { + type: 'object', + properties: { + query: { type: 'string' }, + }, + required: ['query'], + additionalProperties: false, + }, + execute: async ({ query }, options) => { + const trimmed = query.trim(); + if (!trimmed) { + return toolError('Query is required for semantic search.'); } const chunks = await this.context.matchWorkspaceDocs( workspaceId, - query, + trimmed, 5, - req.signal + options.signal ); + const abortedAfterMatch = abortIfNeeded(options.signal); + if (abortedAfterMatch) return abortedAfterMatch; + const docs = await this.ac .user(userId) .workspace(workspaceId) .docs( - chunks.filter(c => 'docId' in c), + chunks.filter(chunk => 'docId' in chunk), 'Doc.Read' ); + const abortedAfterDocs = abortIfNeeded(options.signal); + if (abortedAfterDocs) return 
abortedAfterDocs; + return { content: docs.map(doc => ({ type: 'text', text: clearEmbeddingChunk(doc).content, })), - } as const; - } - ); - - server.registerTool( - 'keyword_search', - { - title: 'Keyword Search', - description: - 'Fuzzy search all workspace documents for the exact keyword or phrase supplied and return passages ranked by textual match. Use this tool by default whenever a straightforward term-based or keyword-base lookup is sufficient.', - inputSchema: z.object({ - query: z.string(), - }), + }; }, - async ({ query }) => { - query = query.trim(); - if (!query) { - return { - isError: true, - content: [ - { - type: 'text', - text: 'Query is required for keyword search.', - }, - ], - }; - } + }); + + const keywordSearch = defineTool({ + name: 'keyword_search', + title: 'Keyword Search', + description: + 'Fuzzy search all workspace documents for the exact keyword or phrase supplied and return passages ranked by textual match. Use this tool by default whenever a straightforward term-based or keyword-base lookup is sufficient.', + parser: z.object({ query: z.string() }), + inputSchema: { + type: 'object', + properties: { + query: { type: 'string' }, + }, + required: ['query'], + additionalProperties: false, + }, + execute: async ({ query }, options) => { + const trimmed = query.trim(); + if (!trimmed) return toolError('Query is required for keyword search.'); + + let docs = await this.indexer.searchDocsByKeyword(workspaceId, trimmed); + + const abortedAfterSearch = abortIfNeeded(options.signal); + if (abortedAfterSearch) return abortedAfterSearch; - let docs = await this.indexer.searchDocsByKeyword(workspaceId, query); docs = await this.ac .user(userId) .workspace(workspaceId) .docs(docs, 'Doc.Read'); + const abortedAfterDocs = abortIfNeeded(options.signal); + if (abortedAfterDocs) return abortedAfterDocs; + return { content: docs.map(doc => ({ type: 'text', text: JSON.stringify(pick(doc, 'docId', 'title', 'createdAt')), })), - } as const; - } - ); + }; 
+ }, + }); + + const tools = [readDocument, semanticSearch, keywordSearch]; if (env.dev || env.namespaces.canary) { - // Write tools - create and update documents - server.registerTool( - 'create_document', - { - title: 'Create Document', - description: - 'Create a new document in the workspace with the given title and markdown content. Returns the ID of the created document. This tool not support insert or update database block and image yet.', - inputSchema: z.object({ - title: z.string().min(1).describe('The title of the new document'), - content: z - .string() - .describe('The markdown content for the document body'), - }), + const createDocument = defineTool({ + name: 'create_document', + title: 'Create Document', + description: + 'Create a new document in the workspace with the given title and markdown content. Returns the ID of the created document. This tool not support insert or update database block and image yet.', + parser: z.object({ + title: z.string().min(1), + content: z.string(), + }), + inputSchema: { + type: 'object', + properties: { + title: { + type: 'string', + description: 'The title of the new document', + }, + content: { + type: 'string', + description: 'The markdown content for the document body', + }, + }, + required: ['title', 'content'], + additionalProperties: false, }, - async ({ title, content }) => { + execute: async ({ title, content }, options) => { try { - // Check if user can create docs in this workspace await this.ac .user(userId) .workspace(workspaceId) .assert('Workspace.CreateDoc'); - // Sanitize title by removing newlines and trimming - const sanitizedTitle = title.replace(/[\r\n]+/g, ' ').trim(); - if (!sanitizedTitle) { - throw new Error('Title cannot be empty'); - } + const abortedAfterPermission = abortIfNeeded(options.signal); + if (abortedAfterPermission) return abortedAfterPermission; - // Strip any leading H1 from content to prevent duplicates - // Per CommonMark spec, ATX headings allow only 0-3 spaces before the 
# - // Handles: "# Title", " # Title", "# Title #" + const sanitizedTitle = title.replace(/[\r\n]+/g, ' ').trim(); + if (!sanitizedTitle) throw new Error('Title cannot be empty'); const strippedContent = content.replace( /^[ \t]{0,3}#\s+[^\n]*#*\s*\n*/, '' ); - - // Create the document const result = await this.writer.createDoc( workspaceId, sanitizedTitle, @@ -211,173 +289,145 @@ export class WorkspaceMcpProvider { userId ); - return { - content: [ - { - type: 'text', - text: JSON.stringify({ - success: true, - docId: result.docId, - message: `Document "${title}" created successfully`, - }), - }, - ], - } as const; + return toolText( + JSON.stringify({ + success: true, + docId: result.docId, + message: `Document "${title}" created successfully`, + }) + ); } catch (error) { - return { - isError: true, - content: [ - { - type: 'text', - text: `Failed to create document: ${error instanceof Error ? error.message : 'Unknown error'}`, - }, - ], - }; + return toolError( + `Failed to create document: ${error instanceof Error ? error.message : 'Unknown error'}` + ); } - } - ); - - server.registerTool( - 'update_document', - { - title: 'Update Document', - description: - 'Update an existing document with new markdown content (body only). Uses structural diffing to apply minimal changes, preserving document history and enabling real-time collaboration. This does NOT update the document title. 
This tool not support insert or update database block and image yet.', - inputSchema: z.object({ - docId: z.string().describe('The ID of the document to update'), - content: z - .string() - .describe( - 'The complete new markdown content for the document body (do NOT include a title H1)' - ), - }), }, - async ({ docId, content }) => { - const notFoundError: CallToolResult = { - isError: true, - content: [ - { - type: 'text', - text: `Doc with id ${docId} not found.`, - }, - ], - }; + }); + + const updateDocument = defineTool({ + name: 'update_document', + title: 'Update Document', + description: + 'Update an existing document with new markdown content (body only). Uses structural diffing to apply minimal changes, preserving document history and enabling real-time collaboration. This does NOT update the document title. This tool not support insert or update database block and image yet.', + parser: z.object({ + docId: z.string(), + content: z.string(), + }), + inputSchema: { + type: 'object', + properties: { + docId: { + type: 'string', + description: 'The ID of the document to update', + }, + content: { + type: 'string', + description: + 'The complete new markdown content for the document body (do NOT include a title H1)', + }, + }, + required: ['docId', 'content'], + additionalProperties: false, + }, + execute: async ({ docId, content }, options) => { + const notFoundError = toolError(`Doc with id ${docId} not found.`); - // Use can() instead of assert() to avoid leaking doc existence info const accessible = await this.ac .user(userId) .workspace(workspaceId) .doc(docId) .can('Doc.Update'); + if (!accessible) return notFoundError; - if (!accessible) { - return notFoundError; - } + const abortedBeforeWrite = abortIfNeeded(options.signal); + if (abortedBeforeWrite) return abortedBeforeWrite; try { - // Update the document await this.writer.updateDoc(workspaceId, docId, content, userId); - - return { - content: [ - { - type: 'text', - text: JSON.stringify({ - 
success: true, - docId, - message: `Document updated successfully`, - }), - }, - ], - } as const; + return toolText( + JSON.stringify({ + success: true, + docId, + message: 'Document updated successfully', + }) + ); } catch (error) { - return { - isError: true, - content: [ - { - type: 'text', - text: `Failed to update document: ${error instanceof Error ? error.message : 'Unknown error'}`, - }, - ], - }; + return toolError( + `Failed to update document: ${error instanceof Error ? error.message : 'Unknown error'}` + ); } - } - ); - - server.registerTool( - 'update_document_meta', - { - title: 'Update Document Metadata', - description: 'Update document metadata (currently title only).', - inputSchema: z.object({ - docId: z.string().describe('The ID of the document to update'), - title: z.string().min(1).describe('The new document title'), - }), }, - async ({ docId, title }) => { - const notFoundError: CallToolResult = { - isError: true, - content: [ - { - type: 'text', - text: `Doc with id ${docId} not found.`, - }, - ], - }; + }); + + const updateDocumentMeta = defineTool({ + name: 'update_document_meta', + title: 'Update Document Metadata', + description: 'Update document metadata (currently title only).', + parser: z.object({ + docId: z.string(), + title: z.string().min(1), + }), + inputSchema: { + type: 'object', + properties: { + docId: { + type: 'string', + description: 'The ID of the document to update', + }, + title: { + type: 'string', + description: 'The new document title', + }, + }, + required: ['docId', 'title'], + additionalProperties: false, + }, + execute: async ({ docId, title }, options) => { + const notFoundError = toolError(`Doc with id ${docId} not found.`); - // Use can() instead of assert() to avoid leaking doc existence info const accessible = await this.ac .user(userId) .workspace(workspaceId) .doc(docId) .can('Doc.Update'); + if (!accessible) return notFoundError; - if (!accessible) { - return notFoundError; - } + const 
abortedAfterPermission = abortIfNeeded(options.signal); + if (abortedAfterPermission) return abortedAfterPermission; try { const sanitizedTitle = title.replace(/[\r\n]+/g, ' ').trim(); - if (!sanitizedTitle) { - throw new Error('Title cannot be empty'); - } + if (!sanitizedTitle) throw new Error('Title cannot be empty'); await this.writer.updateDocMeta( workspaceId, docId, - { - title: sanitizedTitle, - }, + { title: sanitizedTitle }, userId ); - return { - content: [ - { - type: 'text', - text: JSON.stringify({ - success: true, - docId, - message: `Document title updated successfully`, - }), - }, - ], - } as const; + return toolText( + JSON.stringify({ + success: true, + docId, + message: 'Document title updated successfully', + }) + ); } catch (error) { - return { - isError: true, - content: [ - { - type: 'text', - text: `Failed to update document metadata: ${error instanceof Error ? error.message : 'Unknown error'}`, - }, - ], - }; + return toolError( + `Failed to update document metadata: ${error instanceof Error ? 
error.message : 'Unknown error'}` + ); } - } - ); + }, + }); + + tools.push(createDocument, updateDocument, updateDocumentMeta); } - return server; + return { + name: `AFFiNE MCP Server for Workspace ${workspaceId}`, + version: '1.0.1', + tools, + }; } } diff --git a/packages/backend/server/src/plugins/copilot/prompt/prompts.ts b/packages/backend/server/src/plugins/copilot/prompt/prompts.ts index 74d4acb911..6696138fc6 100644 --- a/packages/backend/server/src/plugins/copilot/prompt/prompts.ts +++ b/packages/backend/server/src/plugins/copilot/prompt/prompts.ts @@ -3,7 +3,7 @@ import { AiPrompt, PrismaClient } from '@prisma/client'; import type { PromptConfig, PromptMessage } from '../providers/types'; -type Prompt = Omit< +export type Prompt = Omit< AiPrompt, | 'id' | 'createdAt' @@ -34,7 +34,6 @@ export const Scenario = { 'Remove background', 'Upscale image', ], - rerank: ['Rerank results'], coding: [ 'Apply Updates', 'Code Artifact', @@ -124,7 +123,7 @@ const workflows: Prompt[] = [ { name: 'workflow:presentation:step2', action: 'workflow:presentation:step2', - model: 'gpt-4o-2024-08-06', + model: 'gpt-5-mini', messages: [ { role: 'system', @@ -143,7 +142,7 @@ const workflows: Prompt[] = [ { name: 'workflow:presentation:step4', action: 'workflow:presentation:step4', - model: 'gpt-4o-2024-08-06', + model: 'gpt-5-mini', messages: [ { role: 'system', @@ -187,7 +186,7 @@ const workflows: Prompt[] = [ { name: 'workflow:brainstorm:step2', action: 'workflow:brainstorm:step2', - model: 'gpt-4o-2024-08-06', + model: 'gpt-5-mini', config: { frequencyPenalty: 0.5, presencePenalty: 0.5, @@ -197,7 +196,8 @@ const workflows: Prompt[] = [ messages: [ { role: 'system', - content: `You are the creator of the mind map. 
You need to analyze and expand on the input and output it according to the indentation formatting template given below without redundancy.\nBelow is an example of indentation for a mind map, the title and content needs to be removed by text replacement and not retained. Please strictly adhere to the hierarchical indentation of the template and my requirements, bold, headings and other formatting (e.g. #, **) are not allowed, a maximum of five levels of indentation is allowed, and the last node of each node should make a judgment on whether to make a detailed statement or not based on the topic:\nexmaple:\n- {topic}\n - {Level 1}\n - {Level 2}\n - {Level 3}\n - {Level 4}\n - {Level 1}\n - {Level 2}\n - {Level 3}\n - {Level 1}\n - {Level 2}\n - {Level 3}`, + content: + 'Use the Markdown nested unordered list syntax without any extra styles or plain text descriptions to analyze and expand the input into a mind map. Regardless of the content, the first-level list should contain only one item, which acts as the root. Each node label must be plain text only. Do not output markdown links, footnotes, citations, URLs, headings, bold text, code fences, or any explanatory text outside the nested list. A maximum of five levels of indentation is allowed.', }, { role: 'assistant', @@ -381,7 +381,11 @@ const textActions: Prompt[] = [ name: 'Transcript audio', action: 'Transcript audio', model: 'gemini-2.5-flash', - optionalModels: ['gemini-2.5-flash', 'gemini-2.5-pro'], + optionalModels: [ + 'gemini-2.5-flash', + 'gemini-2.5-pro', + 'gemini-3.1-pro-preview', + ], messages: [ { role: 'system', @@ -414,25 +418,10 @@ Convert a multi-speaker audio recording into a structured JSON format by transcr maxRetries: 1, }, }, - { - name: 'Rerank results', - action: 'Rerank results', - model: 'gpt-4.1', - messages: [ - { - role: 'system', - content: `Judge whether the Document meets the requirements based on the Query and the Instruct provided. 
The answer must be "yes" or "no".`, - }, - { - role: 'user', - content: `: Given a document search result, determine whether the result is relevant to the query.\n: {{query}}\n: {{doc}}`, - }, - ], - }, { name: 'Generate a caption', action: 'Generate a caption', - model: 'gpt-5-mini', + model: 'gemini-2.5-flash', messages: [ { role: 'user', @@ -448,7 +437,7 @@ Convert a multi-speaker audio recording into a structured JSON format by transcr { name: 'Conversation Summary', action: 'Conversation Summary', - model: 'gpt-4.1-2025-04-14', + model: 'gpt-5-mini', messages: [ { role: 'system', @@ -473,7 +462,7 @@ Return only the summary text—no headings, labels, or commentary.`, { name: 'Summary', action: 'Summary', - model: 'gpt-4.1-2025-04-14', + model: 'gpt-5-mini', messages: [ { role: 'system', @@ -504,7 +493,7 @@ You are an assistant helping summarize a document. Use this format, replacing te { name: 'Summary as title', action: 'Summary as title', - model: 'gpt-4.1-2025-04-14', + model: 'gpt-5-mini', messages: [ { role: 'system', @@ -521,7 +510,7 @@ You are an assistant helping summarize a document. Use this format, replacing te { name: 'Summary the webpage', action: 'Summary the webpage', - model: 'gpt-4.1-2025-04-14', + model: 'gpt-5-mini', messages: [ { role: 'user', @@ -533,7 +522,7 @@ You are an assistant helping summarize a document. 
Use this format, replacing te { name: 'Explain this', action: 'Explain this', - model: 'gpt-4.1-2025-04-14', + model: 'gpt-5-mini', messages: [ { role: 'system', @@ -576,7 +565,7 @@ A concise paragraph that captures the article's main argument and key conclusion { name: 'Explain this image', action: 'Explain this image', - model: 'gpt-4.1-2025-04-14', + model: 'gemini-2.5-flash', messages: [ { role: 'system', @@ -727,7 +716,7 @@ You are a highly accomplished professional translator, demonstrating profound pr { name: 'Summarize the meeting', action: 'Summarize the meeting', - model: 'gpt-4.1-2025-04-14', + model: 'gpt-5-mini', messages: [ { role: 'system', @@ -752,7 +741,7 @@ You are an assistant helping summarize a document. Use this format, replacing te { name: 'Find action for summary', action: 'Find action for summary', - model: 'gpt-4.1-2025-04-14', + model: 'gpt-5-mini', messages: [ { role: 'system', @@ -774,7 +763,7 @@ You are an assistant helping find actions of meeting summary. Use this format, r { name: 'Write an article about this', action: 'Write an article about this', - model: 'gemini-2.5-flash', + model: 'gemini-2.5-pro', messages: [ { role: 'system', @@ -829,7 +818,7 @@ You are an assistant helping find actions of meeting summary. Use this format, r { name: 'Write a twitter about this', action: 'Write a twitter about this', - model: 'gpt-4.1-2025-04-14', + model: 'gemini-2.5-flash', messages: [ { role: 'system', @@ -915,7 +904,7 @@ You are an assistant helping find actions of meeting summary. Use this format, r { name: 'Write a blog post about this', action: 'Write a blog post about this', - model: 'gemini-2.5-flash', + model: 'gemini-2.5-pro', messages: [ { role: 'system', @@ -1005,7 +994,7 @@ You are an assistant helping find actions of meeting summary. 
Use this format, r { name: 'Change tone to', action: 'Change tone', - model: 'gpt-4.1-2025-04-14', + model: 'gemini-2.5-flash', messages: [ { role: 'system', @@ -1096,12 +1085,12 @@ You are an assistant helping find actions of meeting summary. Use this format, r { name: 'Brainstorm mindmap', action: 'Brainstorm mindmap', - model: 'gpt-4o-2024-08-06', + model: 'gpt-5-mini', messages: [ { role: 'system', content: - 'Use the Markdown nested unordered list syntax without any extra styles or plain text descriptions to brainstorm the questions or topics provided by user for a mind map. Regardless of the content, the first-level list should contain only one item, which acts as the root. Do not wrap everything into a single code block.', + 'Use the Markdown nested unordered list syntax without any extra styles or plain text descriptions to brainstorm the questions or topics provided by user for a mind map. Regardless of the content, the first-level list should contain only one item, which acts as the root. Each node label must be plain text only. Do not output markdown links, footnotes, citations, URLs, headings, bold text, code fences, or any explanatory text outside the nested list.', }, { role: 'user', @@ -1113,12 +1102,12 @@ You are an assistant helping find actions of meeting summary. Use this format, r { name: 'Expand mind map', action: 'Expand mind map', - model: 'gpt-4o-2024-08-06', + model: 'gpt-5-mini', messages: [ { role: 'system', content: - 'You are a professional writer. Use the Markdown nested unordered list syntax without any extra styles or plain text descriptions to brainstorm the questions or topics provided by user for a mind map.', + 'You are a professional writer. Use the Markdown nested unordered list syntax without any extra styles or plain text descriptions to expand the selected node in a mind map. 
The output must be exactly one subtree: the first bullet must repeat the selected node text as the subtree root, and it must include at least one new nested child bullet beneath it. Each node label must be plain text only. Do not output markdown links, footnotes, citations, URLs, headings, bold text, code fences, or any explanatory text outside the nested list.', }, { role: 'user', @@ -1190,7 +1179,7 @@ The output must be perfect. Adherence to every detail of these instructions is n { name: 'Improve grammar for it', action: 'Improve grammar for it', - model: 'gpt-4.1-2025-04-14', + model: 'gpt-5-mini', messages: [ { role: 'system', @@ -1259,7 +1248,7 @@ The output must be perfect. Adherence to every detail of these instructions is n { name: 'Find action items from it', action: 'Find action items from it', - model: 'gpt-4.1-2025-04-14', + model: 'gpt-5-mini', messages: [ { role: 'system', @@ -1283,7 +1272,7 @@ If there are items in the content that can be used as to-do tasks, please refer { name: 'Check code error', action: 'Check code error', - model: 'gpt-4.1-2025-04-14', + model: 'gpt-5-mini', messages: [ { role: 'system', @@ -1343,7 +1332,7 @@ If there are items in the content that can be used as to-do tasks, please refer { name: 'Create a presentation', action: 'Create a presentation', - model: 'gpt-4o-2024-08-06', + model: 'gpt-5-mini', messages: [ { role: 'system', @@ -1518,7 +1507,7 @@ When sent new notes, respond ONLY with the contents of the html file.`, { name: 'Continue writing', action: 'Continue writing', - model: 'gemini-2.5-flash', + model: 'gemini-2.5-pro', messages: [ { role: 'system', @@ -1904,6 +1893,7 @@ const CHAT_PROMPT: Omit = { optionalModels: [ 'gemini-2.5-flash', 'gemini-2.5-pro', + 'gemini-3.1-pro-preview', 'claude-sonnet-4-5@20250929', ], messages: [ @@ -2074,7 +2064,11 @@ Below is the user's query. 
Please respond in the user's preferred language witho 'codeArtifact', 'blobRead', ], - proModels: ['gemini-2.5-pro', 'claude-sonnet-4-5@20250929'], + proModels: [ + 'gemini-2.5-pro', + 'gemini-3.1-pro-preview', + 'claude-sonnet-4-5@20250929', + ], }, }; @@ -2095,17 +2089,14 @@ export const prompts: Prompt[] = [ export async function refreshPrompts(db: PrismaClient) { const needToSkip = await db.aiPrompt - .findMany({ - where: { modified: true }, - select: { name: true }, - }) + .findMany({ where: { modified: true }, select: { name: true } }) .then(p => p.map(p => p.name)); for (const prompt of prompts) { // skip prompt update if already modified by admin panel if (needToSkip.includes(prompt.name)) { new Logger('CopilotPrompt').warn(`Skip modified prompt: ${prompt.name}`); - return; + continue; } await db.aiPrompt.upsert({ diff --git a/packages/backend/server/src/plugins/copilot/prompt/service.ts b/packages/backend/server/src/plugins/copilot/prompt/service.ts index 789c1722a8..c89ba31be7 100644 --- a/packages/backend/server/src/plugins/copilot/prompt/service.ts +++ b/packages/backend/server/src/plugins/copilot/prompt/service.ts @@ -12,6 +12,7 @@ import { import { ChatPrompt } from './chat-prompt'; import { CopilotPromptScenario, + type Prompt, prompts, refreshPrompts, Scenario, @@ -21,6 +22,7 @@ import { export class PromptService implements OnApplicationBootstrap { private readonly logger = new Logger(PromptService.name); private readonly cache = new Map(); + private readonly inMemoryPrompts = new Map(); constructor( private readonly config: Config, @@ -28,7 +30,7 @@ export class PromptService implements OnApplicationBootstrap { ) {} async onApplicationBootstrap() { - this.cache.clear(); + this.resetInMemoryPrompts(); await refreshPrompts(this.db); } @@ -45,6 +47,7 @@ export class PromptService implements OnApplicationBootstrap { } protected async setup(scenarios?: CopilotPromptScenario) { + this.ensureInMemoryPrompts(); if (!!scenarios && 
scenarios.override_enabled && scenarios.scenarios) { this.logger.log('Updating prompts based on scenarios...'); for (const [scenario, model] of Object.entries(scenarios.scenarios)) { @@ -75,25 +78,29 @@ export class PromptService implements OnApplicationBootstrap { * @returns prompt names */ async listNames() { - return this.db.aiPrompt - .findMany({ select: { name: true } }) - .then(prompts => Array.from(new Set(prompts.map(p => p.name)))); + this.ensureInMemoryPrompts(); + return Array.from(this.inMemoryPrompts.keys()); } async list() { - return this.db.aiPrompt.findMany({ - select: { - name: true, - action: true, - model: true, - config: true, - messages: { - select: { role: true, content: true, params: true }, - orderBy: { idx: 'asc' }, - }, - }, - orderBy: { action: { sort: 'asc', nulls: 'first' } }, - }); + this.ensureInMemoryPrompts(); + return Array.from(this.inMemoryPrompts.values()) + .map(prompt => ({ + name: prompt.name, + action: prompt.action ?? null, + model: prompt.model, + config: prompt.config ? structuredClone(prompt.config) : null, + messages: prompt.messages.map(message => ({ + role: message.role, + content: message.content, + params: message.params ?? null, + })), + })) + .sort((a, b) => { + if (a.action === null && b.action !== null) return -1; + if (a.action !== null && b.action === null) return 1; + return (a.action ?? '').localeCompare(b.action ?? 
''); + }); } /** @@ -102,40 +109,24 @@ export class PromptService implements OnApplicationBootstrap { * @returns prompt messages */ async get(name: string): Promise { + this.ensureInMemoryPrompts(); + // skip cache in dev mode to ensure the latest prompt is always fetched if (!env.dev) { const cached = this.cache.get(name); if (cached) return cached; } - const prompt = await this.db.aiPrompt.findUnique({ - where: { - name, - }, - select: { - name: true, - action: true, - model: true, - optionalModels: true, - config: true, - messages: { - select: { - role: true, - content: true, - params: true, - }, - orderBy: { - idx: 'asc', - }, - }, - }, - }); + const prompt = this.inMemoryPrompts.get(name); + if (!prompt) return null; - const messages = PromptMessageSchema.array().safeParse(prompt?.messages); - const config = PromptConfigSchema.safeParse(prompt?.config); - if (prompt && messages.success && config.success) { + const messages = PromptMessageSchema.array().safeParse(prompt.messages); + const config = PromptConfigSchema.safeParse(prompt.config); + if (messages.success && config.success) { const chatPrompt = ChatPrompt.createFromPrompt({ - ...prompt, + ...this.clonePrompt(prompt), + action: prompt.action ?? null, + optionalModels: prompt.optionalModels ?? [], config: config.data, messages: messages.data, }); @@ -149,25 +140,69 @@ export class PromptService implements OnApplicationBootstrap { name: string, model: string, messages: PromptMessage[], - config?: PromptConfig | null + config?: PromptConfig | null, + extraConfig?: { optionalModels: string[] } ) { - return await this.db.aiPrompt - .create({ - data: { - name, - model, - config: config || undefined, - messages: { - create: messages.map((m, idx) => ({ - idx, - ...m, - attachments: m.attachments || undefined, - params: m.params || undefined, - })), + this.ensureInMemoryPrompts(); + + const existing = this.inMemoryPrompts.get(name); + const mergedOptionalModels = existing?.optionalModels + ? 
[...existing.optionalModels, ...(extraConfig?.optionalModels ?? [])] + : extraConfig?.optionalModels; + const inMemoryConfig = (!!config && structuredClone(config)) || undefined; + const dbConfig = this.toDbConfig(config); + this.inMemoryPrompts.set(name, { + name, + model, + action: existing?.action, + optionalModels: mergedOptionalModels, + config: inMemoryConfig, + messages: this.cloneMessages(messages), + }); + this.cache.delete(name); + + try { + return await this.db.aiPrompt + .upsert({ + where: { name }, + create: { + name, + action: existing?.action, + model, + optionalModels: mergedOptionalModels, + config: dbConfig, + messages: { + create: messages.map((m, idx) => ({ + idx, + ...m, + attachments: m.attachments || undefined, + params: m.params || undefined, + })), + }, }, - }, - }) - .then(ret => ret.id); + update: { + model, + optionalModels: mergedOptionalModels, + config: dbConfig, + updatedAt: new Date(), + messages: { + deleteMany: {}, + create: messages.map((m, idx) => ({ + idx, + ...m, + attachments: m.attachments || undefined, + params: m.params || undefined, + })), + }, + }, + }) + .then(ret => ret.id); + } catch (error) { + this.logger.warn( + `Compat prompt upsert failed for "${name}": ${this.stringifyError(error)}` + ); + return -1; + } } @Transactional() @@ -177,44 +212,123 @@ export class PromptService implements OnApplicationBootstrap { messages?: PromptMessage[]; model?: string; modified?: boolean; - config?: PromptConfig; + config?: PromptConfig | null; }, where?: Prisma.AiPromptWhereInput ) { + this.ensureInMemoryPrompts(); const { config, messages, model, modified } = data; - const existing = await this.db.aiPrompt - .count({ where: { ...where, name } }) - .then(count => count > 0); - if (existing) { - await this.db.aiPrompt.update({ - where: { name }, - data: { - config: config || undefined, - updatedAt: new Date(), - modified, - model, - messages: messages - ? 
{ - // cleanup old messages - deleteMany: {}, - create: messages.map((m, idx) => ({ - idx, - ...m, - attachments: m.attachments || undefined, - params: m.params || undefined, - })), - } - : undefined, - }, - }); + const current = this.inMemoryPrompts.get(name); + if (current) { + const next = this.clonePrompt(current); + if (model !== undefined) { + next.model = model; + } + if (config === null) { + next.config = undefined; + } else if (config !== undefined) { + next.config = structuredClone(config); + } + if (messages) { + next.messages = this.cloneMessages(messages); + } + + this.inMemoryPrompts.set(name, next); this.cache.delete(name); } + + try { + const existing = await this.db.aiPrompt + .count({ where: { ...where, name } }) + .then(count => count > 0); + if (existing) { + await this.db.aiPrompt.update({ + where: { name }, + data: { + config: this.toDbConfig(config), + updatedAt: new Date(), + modified, + model, + messages: messages + ? { + // cleanup old messages + deleteMany: {}, + create: messages.map((m, idx) => ({ + idx, + ...m, + attachments: m.attachments || undefined, + params: m.params || undefined, + })), + } + : undefined, + }, + }); + } + } catch (error) { + this.logger.warn( + `Compat prompt update failed for "${name}": ${this.stringifyError(error)}` + ); + } } async delete(name: string) { - const { id } = await this.db.aiPrompt.delete({ where: { name } }); + this.inMemoryPrompts.delete(name); this.cache.delete(name); - return id; + + try { + const { id } = await this.db.aiPrompt.delete({ where: { name } }); + return id; + } catch (error) { + this.logger.warn( + `Compat prompt delete failed for "${name}": ${this.stringifyError(error)}` + ); + return -1; + } + } + + private resetInMemoryPrompts() { + this.cache.clear(); + this.inMemoryPrompts.clear(); + for (const prompt of prompts) { + this.inMemoryPrompts.set(prompt.name, this.clonePrompt(prompt)); + } + } + + private ensureInMemoryPrompts() { + if (!this.inMemoryPrompts.size) { + 
this.resetInMemoryPrompts(); + } + } + + private toDbConfig( + config: PromptConfig | null | undefined + ): Prisma.InputJsonValue | Prisma.NullableJsonNullValueInput | undefined { + if (config === null) return Prisma.DbNull; + if (config === undefined) return undefined; + return config as Prisma.InputJsonValue; + } + + private cloneMessages(messages: PromptMessage[]) { + return messages.map(message => ({ + ...message, + attachments: message.attachments ? [...message.attachments] : undefined, + params: message.params ? structuredClone(message.params) : undefined, + })); + } + + private clonePrompt(prompt: Prompt): Prompt { + return { + ...prompt, + optionalModels: prompt.optionalModels + ? [...prompt.optionalModels] + : undefined, + config: prompt.config ? structuredClone(prompt.config) : undefined, + messages: this.cloneMessages(prompt.messages), + }; + } + + private stringifyError(error: unknown) { + return error instanceof Error ? error.message : String(error); } } diff --git a/packages/backend/server/src/plugins/copilot/providers/anthropic/anthropic.ts b/packages/backend/server/src/plugins/copilot/providers/anthropic/anthropic.ts index 483188cf69..ab3e512882 100644 --- a/packages/backend/server/src/plugins/copilot/providers/anthropic/anthropic.ts +++ b/packages/backend/server/src/plugins/copilot/providers/anthropic/anthropic.ts @@ -1,52 +1,87 @@ -import { - type AnthropicProvider as AnthropicSDKProvider, - type AnthropicProviderOptions, -} from '@ai-sdk/anthropic'; -import { type GoogleVertexAnthropicProvider } from '@ai-sdk/google-vertex/anthropic'; -import { AISDKError, generateText, stepCountIs, streamText } from 'ai'; - import { CopilotProviderSideError, metrics, UserFriendlyError, } from '../../../../base'; +import { + llmDispatchStream, + type NativeLlmBackendConfig, + type NativeLlmRequest, +} from '../../../../native'; +import type { NodeTextMiddleware } from '../../config'; +import type { CopilotToolSet } from '../../tools'; +import { 
buildNativeRequest, NativeProviderAdapter } from '../native'; import { CopilotProvider } from '../provider'; import type { CopilotChatOptions, - CopilotProviderModel, ModelConditions, PromptMessage, StreamObject, } from '../types'; -import { ModelOutputType } from '../types'; +import { CopilotProviderType, ModelOutputType } from '../types'; import { - chatToGPTMessage, - StreamObjectParser, - TextStreamParser, + getGoogleAuth, + getVertexAnthropicBaseUrl, + type VertexAnthropicProviderConfig, } from '../utils'; export abstract class AnthropicProvider extends CopilotProvider { - protected abstract instance: - | AnthropicSDKProvider - | GoogleVertexAnthropicProvider; - private handleError(e: any) { if (e instanceof UserFriendlyError) { return e; - } else if (e instanceof AISDKError) { - this.logger.error('Throw error from ai sdk:', e); - return new CopilotProviderSideError({ - provider: this.type, - kind: e.name || 'unknown', - message: e.message, - }); - } else { - return new CopilotProviderSideError({ - provider: this.type, - kind: 'unexpected_response', - message: e?.message || 'Unexpected anthropic response', - }); } + return new CopilotProviderSideError({ + provider: this.type, + kind: 'unexpected_response', + message: e?.message || 'Unexpected anthropic response', + }); + } + + private async createNativeConfig(): Promise { + if (this.type === CopilotProviderType.AnthropicVertex) { + const config = this.config as VertexAnthropicProviderConfig; + const auth = await getGoogleAuth(config, 'anthropic'); + const { Authorization: authHeader } = auth.headers(); + const token = authHeader.replace(/^Bearer\s+/i, ''); + const baseUrl = getVertexAnthropicBaseUrl(config) || auth.baseUrl; + return { + base_url: baseUrl || '', + auth_token: token, + request_layer: 'vertex_anthropic', + headers: { Authorization: authHeader }, + }; + } + + const config = this.config as { apiKey: string; baseURL?: string }; + const baseUrl = config.baseURL || 'https://api.anthropic.com/v1'; + 
return { + base_url: baseUrl.replace(/\/v1\/?$/, ''), + auth_token: config.apiKey, + }; + } + + private createAdapter( + backendConfig: NativeLlmBackendConfig, + tools: CopilotToolSet, + nodeTextMiddleware?: NodeTextMiddleware[] + ) { + return new NativeProviderAdapter( + (request: NativeLlmRequest, signal?: AbortSignal) => + llmDispatchStream('anthropic', backendConfig, request, signal), + tools, + this.MAX_STEPS, + { nodeTextMiddleware } + ); + } + + private getReasoning( + options: NonNullable, + model: string + ): Record | undefined { + if (options.reasoning && this.isReasoningModel(model)) { + return { budget_tokens: 12000, include_thought: true }; + } + return undefined; } async text( @@ -55,32 +90,39 @@ export abstract class AnthropicProvider extends CopilotProvider { options: CopilotChatOptions = {} ): Promise { const fullCond = { ...cond, outputType: ModelOutputType.Text }; - await this.checkParams({ cond: fullCond, messages, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + cond: fullCond, + messages, + options, + }); + const model = this.selectModel(normalizedCond); try { - metrics.ai.counter('chat_text_calls').add(1, { model: model.id }); - - const [system, msgs] = await chatToGPTMessage(messages, true, true); - - const modelInstance = this.instance(model.id); - const { text, reasoning } = await generateText({ - model: modelInstance, - system, - messages: msgs, - abortSignal: options.signal, - providerOptions: { - anthropic: this.getAnthropicOptions(options, model.id), - }, - tools: await this.getTools(options, model.id), - stopWhen: stepCountIs(this.MAX_STEPS), + metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id)); + const backendConfig = await this.createNativeConfig(); + const tools = await this.getTools(options, model.id); + const middleware = this.getActiveProviderMiddleware(); + const reasoning = this.getReasoning(options, model.id); + const cap = 
this.getAttachCapability(model, ModelOutputType.Text); + const { request } = await buildNativeRequest({ + model: model.id, + messages, + options, + tools, + attachmentCapability: cap, + reasoning, + middleware, }); - - if (!text) throw new Error('Failed to generate text'); - - return reasoning ? `${reasoning}\n${text}` : text; + const adapter = this.createAdapter( + backendConfig, + tools, + middleware.node?.text + ); + return await adapter.text(request, options.signal, messages); } catch (e: any) { - metrics.ai.counter('chat_text_errors').add(1, { model: model.id }); + metrics.ai + .counter('chat_text_errors') + .add(1, this.metricLabels(model.id)); throw this.handleError(e); } } @@ -91,29 +133,46 @@ export abstract class AnthropicProvider extends CopilotProvider { options: CopilotChatOptions = {} ): AsyncIterable { const fullCond = { ...cond, outputType: ModelOutputType.Text }; - await this.checkParams({ cond: fullCond, messages, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + cond: fullCond, + messages, + options, + }); + const model = this.selectModel(normalizedCond); try { - metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id }); - const fullStream = await this.getFullStream(model, messages, options); - const parser = new TextStreamParser(); - for await (const chunk of fullStream) { - const result = parser.parse(chunk); - yield result; - if (options.signal?.aborted) { - await fullStream.cancel(); - break; - } - } - if (!options.signal?.aborted) { - const footnotes = parser.end(); - if (footnotes.length) { - yield `\n\n${footnotes}`; - } + metrics.ai + .counter('chat_text_stream_calls') + .add(1, this.metricLabels(model.id)); + const backendConfig = await this.createNativeConfig(); + const tools = await this.getTools(options, model.id); + const middleware = this.getActiveProviderMiddleware(); + const cap = this.getAttachCapability(model, ModelOutputType.Text); + const { request } = 
await buildNativeRequest({ + model: model.id, + messages, + options, + tools, + attachmentCapability: cap, + reasoning: this.getReasoning(options, model.id), + middleware, + }); + const adapter = this.createAdapter( + backendConfig, + tools, + middleware.node?.text + ); + for await (const chunk of adapter.streamText( + request, + options.signal, + messages + )) { + yield chunk; } } catch (e: any) { - metrics.ai.counter('chat_text_stream_errors').add(1, { model: model.id }); + metrics.ai + .counter('chat_text_stream_errors') + .add(1, this.metricLabels(model.id)); throw this.handleError(e); } } @@ -124,64 +183,50 @@ export abstract class AnthropicProvider extends CopilotProvider { options: CopilotChatOptions = {} ): AsyncIterable { const fullCond = { ...cond, outputType: ModelOutputType.Object }; - await this.checkParams({ cond: fullCond, messages, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + cond: fullCond, + messages, + options, + }); + const model = this.selectModel(normalizedCond); try { metrics.ai .counter('chat_object_stream_calls') - .add(1, { model: model.id }); - const fullStream = await this.getFullStream(model, messages, options); - const parser = new StreamObjectParser(); - for await (const chunk of fullStream) { - const result = parser.parse(chunk); - if (result) { - yield result; - } - if (options.signal?.aborted) { - await fullStream.cancel(); - break; - } + .add(1, this.metricLabels(model.id)); + const backendConfig = await this.createNativeConfig(); + const tools = await this.getTools(options, model.id); + const middleware = this.getActiveProviderMiddleware(); + const cap = this.getAttachCapability(model, ModelOutputType.Object); + const { request } = await buildNativeRequest({ + model: model.id, + messages, + options, + tools, + attachmentCapability: cap, + reasoning: this.getReasoning(options, model.id), + middleware, + }); + const adapter = this.createAdapter( + backendConfig, + 
tools, + middleware.node?.text + ); + for await (const chunk of adapter.streamObject( + request, + options.signal, + messages + )) { + yield chunk; } } catch (e: any) { metrics.ai .counter('chat_object_stream_errors') - .add(1, { model: model.id }); + .add(1, this.metricLabels(model.id)); throw this.handleError(e); } } - private async getFullStream( - model: CopilotProviderModel, - messages: PromptMessage[], - options: CopilotChatOptions = {} - ) { - const [system, msgs] = await chatToGPTMessage(messages, true, true); - const { fullStream } = streamText({ - model: this.instance(model.id), - system, - messages: msgs, - abortSignal: options.signal, - providerOptions: { - anthropic: this.getAnthropicOptions(options, model.id), - }, - tools: await this.getTools(options, model.id), - stopWhen: stepCountIs(this.MAX_STEPS), - }); - return fullStream; - } - - private getAnthropicOptions(options: CopilotChatOptions, model: string) { - const result: AnthropicProviderOptions = {}; - if (options?.reasoning && this.isReasoningModel(model)) { - result.thinking = { - type: 'enabled', - budgetTokens: 12000, - }; - } - return result; - } - private isReasoningModel(model: string) { // claude 3.5 sonnet doesn't support reasoning config return model.includes('sonnet') && !model.startsWith('claude-3-5-sonnet'); diff --git a/packages/backend/server/src/plugins/copilot/providers/anthropic/official.ts b/packages/backend/server/src/plugins/copilot/providers/anthropic/official.ts index 5d69aa8771..3252b32457 100644 --- a/packages/backend/server/src/plugins/copilot/providers/anthropic/official.ts +++ b/packages/backend/server/src/plugins/copilot/providers/anthropic/official.ts @@ -1,9 +1,6 @@ -import { - type AnthropicProvider as AnthropicSDKProvider, - createAnthropic, -} from '@ai-sdk/anthropic'; import z from 'zod'; +import { IMAGE_ATTACHMENT_CAPABILITY } from '../attachments'; import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types'; import { AnthropicProvider } 
from './anthropic'; @@ -27,6 +24,7 @@ export class AnthropicOfficialProvider extends AnthropicProvider { override readonly type = CopilotProviderType.AnthropicVertex; @@ -21,6 +21,7 @@ export class AnthropicVertexProvider extends AnthropicProvider; + isRemote: boolean; +}; + +function parseDataUrl(url: string) { + if (!url.startsWith('data:')) { + return null; + } + + const commaIndex = url.indexOf(','); + if (commaIndex === -1) { + return null; + } + + const meta = url.slice(5, commaIndex); + const payload = url.slice(commaIndex + 1); + const parts = meta.split(';'); + const mediaType = parts[0] || 'text/plain;charset=US-ASCII'; + const isBase64 = parts.includes('base64'); + + return { + mediaType, + data: isBase64 + ? payload + : Buffer.from(decodeURIComponent(payload), 'utf8').toString('base64'), + }; +} + +function attachmentTypeFromMediaType(mediaType: string): PromptAttachmentKind { + if (mediaType.startsWith('image/')) { + return 'image'; + } + if (mediaType.startsWith('audio/')) { + return 'audio'; + } + return 'file'; +} + +function attachmentKindFromHintOrMediaType( + hint: PromptAttachmentKind | undefined, + mediaType: string | undefined +): PromptAttachmentKind { + if (hint) return hint; + return attachmentTypeFromMediaType(mediaType || ''); +} + +function toBase64Data(data: string, encoding: 'base64' | 'utf8' = 'base64') { + return encoding === 'base64' + ? 
data + : Buffer.from(data, 'utf8').toString('base64'); +} + +function appendAttachMetadata( + source: Record, + attachment: Exclude & Record +) { + if (attachment.fileName) { + source.file_name = attachment.fileName; + } + if (attachment.providerHint) { + source.provider_hint = attachment.providerHint; + } + return source; +} + +export function promptAttachmentHasSource( + attachment: PromptAttachment +): boolean { + if (typeof attachment === 'string') { + return !!attachment.trim(); + } + + if ('attachment' in attachment) { + return !!attachment.attachment; + } + + switch (attachment.kind) { + case 'url': + return !!attachment.url; + case 'data': + case 'bytes': + return !!attachment.data; + case 'file_handle': + return !!attachment.fileHandle; + } +} + +export async function canonicalizePromptAttachment( + attachment: PromptAttachment, + message: Pick +): Promise { + const fallbackMimeType = + typeof message.params?.mimetype === 'string' + ? message.params.mimetype + : undefined; + + if (typeof attachment === 'string') { + const dataUrl = parseDataUrl(attachment); + const mediaType = + fallbackMimeType ?? + dataUrl?.mediaType ?? + (await inferMimeType(attachment)); + const kind = attachmentKindFromHintOrMediaType(undefined, mediaType); + if (dataUrl) { + return { + kind, + sourceKind: 'data', + mediaType, + isRemote: false, + source: { + media_type: mediaType || dataUrl.mediaType, + data: dataUrl.data, + }, + }; + } + + return { + kind, + sourceKind: 'url', + mediaType, + isRemote: /^https?:\/\//.test(attachment), + source: { url: attachment, media_type: mediaType }, + }; + } + + if ('attachment' in attachment) { + return await canonicalizePromptAttachment( + { + kind: 'url', + url: attachment.attachment, + mimeType: attachment.mimeType, + }, + message + ); + } + + if (attachment.kind === 'url') { + const dataUrl = parseDataUrl(attachment.url); + const mediaType = + attachment.mimeType ?? + fallbackMimeType ?? + dataUrl?.mediaType ?? 
+ (await inferMimeType(attachment.url)); + const kind = attachmentKindFromHintOrMediaType( + attachment.providerHint?.kind, + mediaType + ); + if (dataUrl) { + return { + kind, + sourceKind: 'data', + mediaType, + isRemote: false, + source: appendAttachMetadata( + { media_type: mediaType || dataUrl.mediaType, data: dataUrl.data }, + attachment + ), + }; + } + + return { + kind, + sourceKind: 'url', + mediaType, + isRemote: /^https?:\/\//.test(attachment.url), + source: appendAttachMetadata( + { url: attachment.url, media_type: mediaType }, + attachment + ), + }; + } + + if (attachment.kind === 'data' || attachment.kind === 'bytes') { + return { + kind: attachmentKindFromHintOrMediaType( + attachment.providerHint?.kind, + attachment.mimeType + ), + sourceKind: attachment.kind, + mediaType: attachment.mimeType, + isRemote: false, + source: appendAttachMetadata( + { + media_type: attachment.mimeType, + data: toBase64Data( + attachment.data, + attachment.kind === 'data' ? attachment.encoding : 'base64' + ), + }, + attachment + ), + }; + } + + return { + kind: attachmentKindFromHintOrMediaType( + attachment.providerHint?.kind, + attachment.mimeType + ), + sourceKind: 'file_handle', + mediaType: attachment.mimeType, + isRemote: false, + source: appendAttachMetadata( + { file_handle: attachment.fileHandle, media_type: attachment.mimeType }, + attachment + ), + }; +} diff --git a/packages/backend/server/src/plugins/copilot/providers/factory.ts b/packages/backend/server/src/plugins/copilot/providers/factory.ts index 0087df9118..f4ef9c983c 100644 --- a/packages/backend/server/src/plugins/copilot/providers/factory.ts +++ b/packages/backend/server/src/plugins/copilot/providers/factory.ts @@ -1,16 +1,141 @@ import { Injectable, Logger } from '@nestjs/common'; +import { Config } from '../../../base'; import { ServerFeature, ServerService } from '../../../core'; import type { CopilotProvider } from './provider'; +import { + buildProviderRegistry, + resolveModel, + 
stripProviderPrefix, +} from './provider-registry'; import { CopilotProviderType, ModelFullConditions } from './types'; +function isAsyncIterable(value: unknown): value is AsyncIterable { + return ( + value !== null && + value !== undefined && + typeof (value as AsyncIterable)[Symbol.asyncIterator] === + 'function' + ); +} + @Injectable() export class CopilotProviderFactory { - constructor(private readonly server: ServerService) {} + constructor( + private readonly server: ServerService, + private readonly config: Config + ) {} private readonly logger = new Logger(CopilotProviderFactory.name); - readonly #providers = new Map(); + readonly #providers = new Map(); + readonly #boundProviders = new Map(); + readonly #providerIdsByType = new Map>(); + + private getRegistry() { + return buildProviderRegistry(this.config.copilot.providers); + } + + private getPreferredProviderIds(type?: CopilotProviderType) { + if (!type) return undefined; + return this.#providerIdsByType.get(type); + } + + private normalizeCond( + providerId: string, + cond: ModelFullConditions + ): ModelFullConditions { + const registry = this.getRegistry(); + const modelId = stripProviderPrefix(registry, providerId, cond.modelId); + return { ...cond, modelId }; + } + + private normalizeMethodArgs(providerId: string, args: unknown[]) { + const [first, ...rest] = args; + if ( + !first || + typeof first !== 'object' || + Array.isArray(first) || + !('modelId' in first) + ) { + return args; + } + + const cond = first as Record; + if (typeof cond.modelId !== 'string') return args; + + const registry = this.getRegistry(); + const modelId = stripProviderPrefix(registry, providerId, cond.modelId); + return [{ ...cond, modelId }, ...rest]; + } + + private wrapAsyncIterable( + provider: CopilotProvider, + providerId: string, + iterable: AsyncIterable + ): AsyncIterableIterator { + const iterator = iterable[Symbol.asyncIterator](); + + return { + next: value => + provider.runWithProfile(providerId, () => 
iterator.next(value)), + return: value => + provider.runWithProfile(providerId, async () => { + if (typeof iterator.return === 'function') { + return iterator.return(value as never); + } + return { done: true, value: value as T }; + }), + throw: error => + provider.runWithProfile(providerId, async () => { + if (typeof iterator.throw === 'function') { + return iterator.throw(error); + } + throw error; + }), + [Symbol.asyncIterator]() { + return this; + }, + }; + } + + private getBoundProvider(providerId: string, provider: CopilotProvider) { + const cached = this.#boundProviders.get(providerId); + if (cached) { + return cached; + } + + const wrapped = new Proxy(provider, { + get: (target, prop, receiver) => { + if (prop === 'providerId') { + return providerId; + } + + const value = Reflect.get(target, prop, receiver); + if (typeof value !== 'function') { + return value; + } + + return (...args: unknown[]) => { + const normalizedArgs = this.normalizeMethodArgs(providerId, args); + const result = provider.runWithProfile(providerId, () => + Reflect.apply(value, provider, normalizedArgs) + ); + if (isAsyncIterable(result)) { + return this.wrapAsyncIterable( + provider, + providerId, + result as AsyncIterable + ); + } + return result; + }; + }, + }) as CopilotProvider; + + this.#boundProviders.set(providerId, wrapped); + return wrapped; + } async getProvider( cond: ModelFullConditions, @@ -21,22 +146,41 @@ export class CopilotProviderFactory { this.logger.debug( `Resolving copilot provider for output type: ${cond.outputType}` ); - let candidate: CopilotProvider | null = null; - for (const [type, provider] of this.#providers.entries()) { - if (filter.prefer && filter.prefer !== type) { + const route = resolveModel({ + registry: this.getRegistry(), + modelId: cond.modelId, + outputType: cond.outputType, + availableProviderIds: this.#providers.keys(), + preferredProviderIds: this.getPreferredProviderIds(filter.prefer), + }); + + const registry = this.getRegistry(); + for 
(const providerId of route.candidateProviderIds) { + const provider = this.#providers.get(providerId); + if (!provider) continue; + + const profile = registry.profiles.get(providerId); + const normalizedCond = this.normalizeCond(providerId, cond); + if ( + normalizedCond.modelId && + profile?.models?.length && + !profile.models.includes(normalizedCond.modelId) + ) { continue; } - const isMatched = await provider.match(cond); + const matched = await provider.runWithProfile(providerId, () => + provider.match(normalizedCond) + ); + if (!matched) continue; - if (isMatched) { - candidate = provider; - this.logger.debug(`Copilot provider candidate found: ${type}`); - break; - } + this.logger.debug( + `Copilot provider candidate found: ${provider.type} (${providerId})` + ); + return this.getBoundProvider(providerId, provider); } - return candidate; + return null; } async getProviderByModel( @@ -46,31 +190,50 @@ export class CopilotProviderFactory { } = {} ): Promise { this.logger.debug(`Resolving copilot provider for model: ${modelId}`); + return this.getProvider({ modelId }, filter); + } - let candidate: CopilotProvider | null = null; - for (const [type, provider] of this.#providers.entries()) { - if (filter.prefer && filter.prefer !== type) { - continue; - } - - if (await provider.match({ modelId })) { - candidate = provider; - this.logger.debug(`Copilot provider candidate found: ${type}`); + register(providerId: string, provider: CopilotProvider) { + const existed = this.#providers.get(providerId); + if (existed?.type && existed.type !== provider.type) { + const ids = this.#providerIdsByType.get(existed.type); + ids?.delete(providerId); + if (!ids?.size) { + this.#providerIdsByType.delete(existed.type); } } - return candidate; - } + this.#providers.set(providerId, provider); + this.#boundProviders.delete(providerId); - register(provider: CopilotProvider) { - this.#providers.set(provider.type, provider); - this.logger.log(`Copilot provider [${provider.type}] 
registered.`); + const ids = this.#providerIdsByType.get(provider.type) ?? new Set(); + ids.add(providerId); + this.#providerIdsByType.set(provider.type, ids); + + this.logger.log( + `Copilot provider [${provider.type}] registered as [${providerId}].` + ); this.server.enableFeature(ServerFeature.Copilot); } - unregister(provider: CopilotProvider) { - this.#providers.delete(provider.type); - this.logger.log(`Copilot provider [${provider.type}] unregistered.`); + unregister(providerId: string, provider: CopilotProvider) { + const existed = this.#providers.get(providerId); + if (!existed || existed !== provider) { + return; + } + + this.#providers.delete(providerId); + this.#boundProviders.delete(providerId); + + const ids = this.#providerIdsByType.get(provider.type); + ids?.delete(providerId); + if (!ids?.size) { + this.#providerIdsByType.delete(provider.type); + } + + this.logger.log( + `Copilot provider [${provider.type}] unregistered from [${providerId}].` + ); if (this.#providers.size === 0) { this.server.disableFeature(ServerFeature.Copilot); } diff --git a/packages/backend/server/src/plugins/copilot/providers/fal.ts b/packages/backend/server/src/plugins/copilot/providers/fal.ts index 486d3e91fd..b6927a141d 100644 --- a/packages/backend/server/src/plugins/copilot/providers/fal.ts +++ b/packages/backend/server/src/plugins/copilot/providers/fal.ts @@ -19,6 +19,7 @@ import type { PromptMessage, } from './types'; import { CopilotProviderType, ModelInputType, ModelOutputType } from './types'; +import { promptAttachmentMimeType, promptAttachmentToUrl } from './utils'; export type FalConfig = { apiKey: string; @@ -183,13 +184,14 @@ export class FalProvider extends CopilotProvider { return { model_name: options.modelName || undefined, image_url: attachments - ?.map(v => - typeof v === 'string' - ? v - : v.mimeType.startsWith('image/') - ? 
v.attachment - : undefined - ) + ?.map(v => { + const url = promptAttachmentToUrl(v); + const mediaType = promptAttachmentMimeType( + v, + typeof params?.mimetype === 'string' ? params.mimetype : undefined + ); + return url && mediaType?.startsWith('image/') ? url : undefined; + }) .find(v => !!v), prompt: content.trim(), loras: lora.length ? lora : undefined, diff --git a/packages/backend/server/src/plugins/copilot/providers/gemini/gemini.ts b/packages/backend/server/src/plugins/copilot/providers/gemini/gemini.ts index 8525f30355..4a5743dd7c 100644 --- a/packages/backend/server/src/plugins/copilot/providers/gemini/gemini.ts +++ b/packages/backend/server/src/plugins/copilot/providers/gemini/gemini.ts @@ -1,58 +1,94 @@ -import type { - GoogleGenerativeAIProvider, - GoogleGenerativeAIProviderOptions, -} from '@ai-sdk/google'; -import type { GoogleVertexProvider } from '@ai-sdk/google-vertex'; -import { - AISDKError, - embedMany, - generateObject, - generateText, - JSONParseError, - stepCountIs, - streamText, -} from 'ai'; +import { setTimeout as delay } from 'node:timers/promises'; + +import { ZodError } from 'zod'; import { - CopilotPromptInvalid, CopilotProviderSideError, metrics, + OneMB, + readResponseBufferWithLimit, + safeFetch, UserFriendlyError, } from '../../../../base'; +import { sniffMime } from '../../../../base/storage/providers/utils'; +import { + llmDispatchStream, + llmEmbeddingDispatch, + llmStructuredDispatch, + type NativeLlmBackendConfig, + type NativeLlmEmbeddingRequest, + type NativeLlmRequest, + type NativeLlmStructuredRequest, +} from '../../../../native'; +import type { NodeTextMiddleware } from '../../config'; +import type { CopilotToolSet } from '../../tools'; +import { + buildNativeEmbeddingRequest, + buildNativeRequest, + buildNativeStructuredRequest, + NativeProviderAdapter, + parseNativeStructuredOutput, + StructuredResponseParseError, +} from '../native'; import { CopilotProvider } from '../provider'; import type { CopilotChatOptions, 
CopilotEmbeddingOptions, CopilotImageOptions, - CopilotProviderModel, + CopilotStructuredOptions, ModelConditions, + PromptAttachment, PromptMessage, StreamObject, } from '../types'; import { ModelOutputType } from '../types'; -import { - chatToGPTMessage, - StreamObjectParser, - TextStreamParser, -} from '../utils'; +import { promptAttachmentMimeType, promptAttachmentToUrl } from '../utils'; export const DEFAULT_DIMENSIONS = 256; +const GEMINI_REMOTE_ATTACHMENT_MAX_BYTES = 64 * OneMB; +const TRUSTED_ATTACHMENT_HOST_SUFFIXES = ['cdn.affine.pro']; +const GEMINI_RETRY_INITIAL_DELAY_MS = 2_000; + +function normalizeMimeType(mediaType?: string) { + return mediaType?.split(';', 1)[0]?.trim() || 'application/octet-stream'; +} + +function isYoutubeUrl(url: URL) { + const hostname = url.hostname.toLowerCase(); + if (hostname === 'youtu.be') { + return /^\/[\w-]+$/.test(url.pathname); + } + + if (hostname !== 'youtube.com' && hostname !== 'www.youtube.com') { + return false; + } + + if (url.pathname !== '/watch') { + return false; + } + + return !!url.searchParams.get('v'); +} + +function isGeminiFileUrl(url: URL, baseUrl: string) { + try { + const base = new URL(baseUrl); + const basePath = base.pathname.replace(/\/+$/, ''); + return ( + url.origin === base.origin && + url.pathname.startsWith(`${basePath}/files/`) + ); + } catch { + return false; + } +} export abstract class GeminiProvider extends CopilotProvider { - protected abstract instance: - | GoogleGenerativeAIProvider - | GoogleVertexProvider; + protected abstract createNativeConfig(): Promise; private handleError(e: any) { if (e instanceof UserFriendlyError) { return e; - } else if (e instanceof AISDKError) { - this.logger.error('Throw error from ai sdk:', e); - return new CopilotProviderSideError({ - provider: this.type, - kind: e.name || 'unknown', - message: e.message, - }); } else { return new CopilotProviderSideError({ provider: this.type, @@ -62,37 +98,261 @@ export abstract class GeminiProvider extends 
CopilotProvider { } } + protected createNativeDispatch(backendConfig: NativeLlmBackendConfig) { + return (request: NativeLlmRequest, signal?: AbortSignal) => + llmDispatchStream('gemini', backendConfig, request, signal); + } + + protected createNativeStructuredDispatch( + backendConfig: NativeLlmBackendConfig + ) { + return (request: NativeLlmStructuredRequest) => + llmStructuredDispatch('gemini', backendConfig, request); + } + + protected createNativeEmbeddingDispatch( + backendConfig: NativeLlmBackendConfig + ) { + return (request: NativeLlmEmbeddingRequest) => + llmEmbeddingDispatch('gemini', backendConfig, request); + } + + protected createNativeAdapter( + backendConfig: NativeLlmBackendConfig, + tools: CopilotToolSet, + nodeTextMiddleware?: NodeTextMiddleware[] + ) { + return new NativeProviderAdapter( + this.createNativeDispatch(backendConfig), + tools, + this.MAX_STEPS, + { nodeTextMiddleware } + ); + } + + protected async fetchRemoteAttach(url: string, signal?: AbortSignal) { + const parsed = new URL(url); + const response = await safeFetch( + parsed, + { method: 'GET', signal }, + this.buildAttachFetchOptions(parsed) + ); + if (!response.ok) { + throw new Error( + `Failed to fetch attachment: ${response.status} ${response.statusText}` + ); + } + const buffer = await readResponseBufferWithLimit( + response, + GEMINI_REMOTE_ATTACHMENT_MAX_BYTES + ); + const headerMimeType = normalizeMimeType( + response.headers.get('content-type') || '' + ); + return { + data: buffer.toString('base64'), + mimeType: normalizeMimeType(sniffMime(buffer, headerMimeType)), + }; + } + + private buildAttachFetchOptions(url: URL) { + const baseOptions = { timeoutMs: 15_000, maxRedirects: 3 } as const; + if (!env.prod) { + return { ...baseOptions, allowPrivateOrigins: new Set([url.origin]) }; + } + + const trustedOrigins = new Set(); + const protocol = this.AFFiNEConfig.server.https ? 
'https:' : 'http:'; + const port = this.AFFiNEConfig.server.port; + const isDefaultPort = + (protocol === 'https:' && port === 443) || + (protocol === 'http:' && port === 80); + + const addHostOrigin = (host: string) => { + if (!host) return; + try { + const parsed = new URL(`${protocol}//${host}`); + if (!parsed.port && !isDefaultPort) { + parsed.port = String(port); + } + trustedOrigins.add(parsed.origin); + } catch { + // ignore invalid host config entries + } + }; + + if (this.AFFiNEConfig.server.externalUrl) { + try { + trustedOrigins.add( + new URL(this.AFFiNEConfig.server.externalUrl).origin + ); + } catch { + // ignore invalid external URL + } + } + + addHostOrigin(this.AFFiNEConfig.server.host); + for (const host of this.AFFiNEConfig.server.hosts) { + addHostOrigin(host); + } + + const hostname = url.hostname.toLowerCase(); + const trustedByHost = TRUSTED_ATTACHMENT_HOST_SUFFIXES.some( + suffix => hostname === suffix || hostname.endsWith(`.${suffix}`) + ); + if (trustedOrigins.has(url.origin) || trustedByHost) { + return { ...baseOptions, allowPrivateOrigins: new Set([url.origin]) }; + } + + return baseOptions; + } + + private shouldInlineRemoteAttach(url: URL, config: NativeLlmBackendConfig) { + switch (config.request_layer) { + case 'gemini_api': + if (url.protocol !== 'http:' && url.protocol !== 'https:') return false; + return !(isGeminiFileUrl(url, config.base_url) || isYoutubeUrl(url)); + case 'gemini_vertex': + return false; + default: + return false; + } + } + + private toInlineAttach( + attachment: PromptAttachment, + mimeType: string, + data: string + ): PromptAttachment { + if (typeof attachment === 'string' || !('kind' in attachment)) { + return { kind: 'bytes', data, mimeType }; + } + + if (attachment.kind !== 'url') { + return attachment; + } + + return { + kind: 'bytes', + data, + mimeType, + fileName: attachment.fileName, + providerHint: attachment.providerHint, + }; + } + + protected async prepareMessages( + messages: PromptMessage[], + 
backendConfig: NativeLlmBackendConfig, + signal?: AbortSignal + ): Promise { + const prepared: PromptMessage[] = []; + + for (const message of messages) { + signal?.throwIfAborted(); + if (!Array.isArray(message.attachments) || !message.attachments.length) { + prepared.push(message); + continue; + } + + const attachments: PromptAttachment[] = []; + let changed = false; + for (const attachment of message.attachments) { + signal?.throwIfAborted(); + const rawUrl = promptAttachmentToUrl(attachment); + if (!rawUrl || rawUrl.startsWith('data:')) { + attachments.push(attachment); + continue; + } + + let parsed: URL; + try { + parsed = new URL(rawUrl); + } catch { + attachments.push(attachment); + continue; + } + + if (!this.shouldInlineRemoteAttach(parsed, backendConfig)) { + attachments.push(attachment); + continue; + } + + const declaredMimeType = promptAttachmentMimeType( + attachment, + typeof message.params?.mimetype === 'string' + ? message.params.mimetype + : undefined + ); + const downloaded = await this.fetchRemoteAttach(rawUrl, signal); + attachments.push( + this.toInlineAttach( + attachment, + declaredMimeType + ? normalizeMimeType(declaredMimeType) + : downloaded.mimeType, + downloaded.data + ) + ); + changed = true; + } + + prepared.push(changed ? { ...message, attachments } : message); + } + + return prepared; + } + + protected async waitForStructuredRetry( + delayMs: number, + signal?: AbortSignal + ) { + await delay(delayMs, undefined, signal ? 
{ signal } : undefined); + } + async text( cond: ModelConditions, messages: PromptMessage[], options: CopilotChatOptions = {} ): Promise { const fullCond = { ...cond, outputType: ModelOutputType.Text }; - await this.checkParams({ cond: fullCond, messages, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + cond: fullCond, + messages, + options, + }); + const model = this.selectModel(normalizedCond); try { - metrics.ai.counter('chat_text_calls').add(1, { model: model.id }); - - const [system, msgs] = await chatToGPTMessage(messages); - - const modelInstance = this.instance(model.id); - const { text } = await generateText({ - model: modelInstance, - system, - messages: msgs, - abortSignal: options.signal, - providerOptions: { - google: this.getGeminiOptions(options, model.id), - }, - tools: await this.getTools(options, model.id), - stopWhen: stepCountIs(this.MAX_STEPS), + metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id)); + const backendConfig = await this.createNativeConfig(); + const msg = await this.prepareMessages( + messages, + backendConfig, + options.signal + ); + const tools = await this.getTools(options, model.id); + const middleware = this.getActiveProviderMiddleware(); + const cap = this.getAttachCapability(model, ModelOutputType.Text); + const { request } = await buildNativeRequest({ + model: model.id, + messages: msg, + options, + tools, + attachmentCapability: cap, + reasoning: this.getReasoning(options, model.id), + middleware, }); - - if (!text) throw new Error('Failed to generate text'); - return text.trim(); + const adapter = this.createNativeAdapter( + backendConfig, + tools, + middleware.node?.text + ); + return await adapter.text(request, options.signal, messages); } catch (e: any) { - metrics.ai.counter('chat_text_errors').add(1, { model: model.id }); + metrics.ai + .counter('chat_text_errors') + .add(1, this.metricLabels(model.id)); throw this.handleError(e); } } @@ 
-100,55 +360,65 @@ export abstract class GeminiProvider extends CopilotProvider { override async structure( cond: ModelConditions, messages: PromptMessage[], - options: CopilotChatOptions = {} + options: CopilotStructuredOptions = {} ): Promise { const fullCond = { ...cond, outputType: ModelOutputType.Structured }; - await this.checkParams({ cond: fullCond, messages, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + cond: fullCond, + messages, + options, + }); + const model = this.selectModel(normalizedCond); try { - metrics.ai.counter('chat_text_calls').add(1, { model: model.id }); - - const [system, msgs, schema] = await chatToGPTMessage(messages); - if (!schema) { - throw new CopilotPromptInvalid('Schema is required'); - } - - const modelInstance = this.instance(model.id); - const { object } = await generateObject({ - model: modelInstance, - system, - messages: msgs, - schema, - providerOptions: { - google: { - thinkingConfig: { - thinkingBudget: -1, - includeThoughts: false, - }, - }, - }, - abortSignal: options.signal, - maxRetries: options.maxRetries || 3, - experimental_repairText: async ({ text, error }) => { - if (error instanceof JSONParseError) { - // strange fixed response, temporarily replace it - const ret = text.replaceAll(/^ny\n/g, ' ').trim(); - if (ret.startsWith('```') || ret.endsWith('```')) { - return ret - .replace(/```[\w\s]+\n/g, '') - .replace(/\n```/g, '') - .trim(); - } - return ret; - } - return null; - }, + metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id)); + const backendConfig = await this.createNativeConfig(); + const msg = await this.prepareMessages( + messages, + backendConfig, + options.signal + ); + const structuredDispatch = + this.createNativeStructuredDispatch(backendConfig); + const middleware = this.getActiveProviderMiddleware(); + const cap = this.getAttachCapability(model, ModelOutputType.Structured); + const { request, schema } = await 
buildNativeStructuredRequest({ + model: model.id, + messages: msg, + options, + attachmentCapability: cap, + reasoning: this.getReasoning(options, model.id), + responseSchema: options.schema, + middleware, }); - - return JSON.stringify(object); + const maxRetries = Math.max(options.maxRetries ?? 3, 0); + for (let attempt = 0; ; attempt++) { + try { + const response = await structuredDispatch(request); + const parsed = parseNativeStructuredOutput(response); + const validated = schema.parse(parsed); + return JSON.stringify(validated); + } catch (error) { + const isParsingError = + error instanceof StructuredResponseParseError || + error instanceof ZodError; + const retryableError = + isParsingError || !(error instanceof UserFriendlyError); + if (!retryableError || attempt >= maxRetries) { + throw error; + } + if (!isParsingError) { + await this.waitForStructuredRetry( + GEMINI_RETRY_INITIAL_DELAY_MS * 2 ** attempt, + options.signal + ); + } + } + } } catch (e: any) { - metrics.ai.counter('chat_text_errors').add(1, { model: model.id }); + metrics.ai + .counter('chat_text_errors') + .add(1, this.metricLabels(model.id)); throw this.handleError(e); } } @@ -159,29 +429,54 @@ export abstract class GeminiProvider extends CopilotProvider { options: CopilotChatOptions | CopilotImageOptions = {} ): AsyncIterable { const fullCond = { ...cond, outputType: ModelOutputType.Text }; - await this.checkParams({ cond: fullCond, messages, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + cond: fullCond, + messages, + options, + }); + const model = this.selectModel(normalizedCond); try { - metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id }); - const fullStream = await this.getFullStream(model, messages, options); - const parser = new TextStreamParser(); - for await (const chunk of fullStream) { - const result = parser.parse(chunk); - yield result; - if (options.signal?.aborted) { - await fullStream.cancel(); - 
break; - } - } - if (!options.signal?.aborted) { - const footnotes = parser.end(); - if (footnotes.length) { - yield `\n\n${footnotes}`; - } + metrics.ai + .counter('chat_text_stream_calls') + .add(1, this.metricLabels(model.id)); + const backendConfig = await this.createNativeConfig(); + const preparedMessages = await this.prepareMessages( + messages, + backendConfig, + options.signal + ); + const tools = await this.getTools( + options as CopilotChatOptions, + model.id + ); + const middleware = this.getActiveProviderMiddleware(); + const cap = this.getAttachCapability(model, ModelOutputType.Text); + const { request } = await buildNativeRequest({ + model: model.id, + messages: preparedMessages, + options: options as CopilotChatOptions, + tools, + attachmentCapability: cap, + reasoning: this.getReasoning(options, model.id), + middleware, + }); + const adapter = this.createNativeAdapter( + backendConfig, + tools, + middleware.node?.text + ); + for await (const chunk of adapter.streamText( + request, + options.signal, + messages + )) { + yield chunk; } } catch (e: any) { - metrics.ai.counter('chat_text_stream_errors').add(1, { model: model.id }); + metrics.ai + .counter('chat_text_stream_errors') + .add(1, this.metricLabels(model.id)); throw this.handleError(e); } } @@ -192,29 +487,51 @@ export abstract class GeminiProvider extends CopilotProvider { options: CopilotChatOptions = {} ): AsyncIterable { const fullCond = { ...cond, outputType: ModelOutputType.Object }; - await this.checkParams({ cond: fullCond, messages, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + cond: fullCond, + messages, + options, + }); + const model = this.selectModel(normalizedCond); try { metrics.ai .counter('chat_object_stream_calls') - .add(1, { model: model.id }); - const fullStream = await this.getFullStream(model, messages, options); - const parser = new StreamObjectParser(); - for await (const chunk of fullStream) { - const 
result = parser.parse(chunk); - if (result) { - yield result; - } - if (options.signal?.aborted) { - await fullStream.cancel(); - break; - } + .add(1, this.metricLabels(model.id)); + const backendConfig = await this.createNativeConfig(); + const msg = await this.prepareMessages( + messages, + backendConfig, + options.signal + ); + const tools = await this.getTools(options, model.id); + const middleware = this.getActiveProviderMiddleware(); + const cap = this.getAttachCapability(model, ModelOutputType.Object); + const { request } = await buildNativeRequest({ + model: model.id, + messages: msg, + options, + tools, + attachmentCapability: cap, + reasoning: this.getReasoning(options, model.id), + middleware, + }); + const adapter = this.createNativeAdapter( + backendConfig, + tools, + middleware.node?.text + ); + for await (const chunk of adapter.streamObject( + request, + options.signal, + messages + )) { + yield chunk; } } catch (e: any) { metrics.ai .counter('chat_object_stream_errors') - .add(1, { model: model.id }); + .add(1, this.metricLabels(model.id)); throw this.handleError(e); } } @@ -224,77 +541,60 @@ export abstract class GeminiProvider extends CopilotProvider { messages: string | string[], options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS } ): Promise { - messages = Array.isArray(messages) ? messages : [messages]; + const values = Array.isArray(messages) ? 
messages : [messages]; const fullCond = { ...cond, outputType: ModelOutputType.Embedding }; - await this.checkParams({ embeddings: messages, cond: fullCond, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + embeddings: values, + cond: fullCond, + options, + }); + const model = this.selectModel(normalizedCond); try { metrics.ai .counter('generate_embedding_calls') - .add(1, { model: model.id }); - - const modelInstance = this.instance.textEmbeddingModel(model.id); - - const embeddings = await Promise.allSettled( - messages.map(m => - embedMany({ - model: modelInstance, - values: [m], - maxRetries: 3, - providerOptions: { - google: { - outputDimensionality: options.dimensions || DEFAULT_DIMENSIONS, - taskType: 'RETRIEVAL_DOCUMENT', - }, - }, - }) - ) + .add(1, this.metricLabels(model.id)); + const backendConfig = await this.createNativeConfig(); + const response = await this.createNativeEmbeddingDispatch(backendConfig)( + buildNativeEmbeddingRequest({ + model: model.id, + inputs: values, + dimensions: options.dimensions || DEFAULT_DIMENSIONS, + taskType: 'RETRIEVAL_DOCUMENT', + }) ); - - return embeddings - .flatMap(e => (e.status === 'fulfilled' ? 
e.value.embeddings : null)) - .filter((v): v is number[] => !!v && Array.isArray(v)); + return response.embeddings; } catch (e: any) { metrics.ai .counter('generate_embedding_errors') - .add(1, { model: model.id }); + .add(1, this.metricLabels(model.id)); throw this.handleError(e); } } - private async getFullStream( - model: CopilotProviderModel, - messages: PromptMessage[], - options: CopilotChatOptions = {} - ) { - const [system, msgs] = await chatToGPTMessage(messages); - const { fullStream } = streamText({ - model: this.instance(model.id), - system, - messages: msgs, - abortSignal: options.signal, - providerOptions: { - google: this.getGeminiOptions(options, model.id), - }, - tools: await this.getTools(options, model.id), - stopWhen: stepCountIs(this.MAX_STEPS), - }); - return fullStream; + protected getReasoning( + options: CopilotChatOptions | CopilotImageOptions, + model: string + ): Record | undefined { + if ( + options && + 'reasoning' in options && + options.reasoning && + this.isReasoningModel(model) + ) { + return this.isGemini3Model(model) + ? 
{ include_thoughts: true, thinking_level: 'high' } + : { include_thoughts: true, thinking_budget: 12000 }; + } + + return undefined; } - private getGeminiOptions(options: CopilotChatOptions, model: string) { - const result: GoogleGenerativeAIProviderOptions = {}; - if (options?.reasoning && this.isReasoningModel(model)) { - result.thinkingConfig = { - thinkingBudget: 12000, - includeThoughts: true, - }; - } - return result; + private isGemini3Model(model: string) { + return model.startsWith('gemini-3'); } private isReasoningModel(model: string) { - return model.startsWith('gemini-2.5'); + return model.startsWith('gemini-2.5') || this.isGemini3Model(model); } } diff --git a/packages/backend/server/src/plugins/copilot/providers/gemini/generative.ts b/packages/backend/server/src/plugins/copilot/providers/gemini/generative.ts index 66eb289528..d29c5a98ba 100644 --- a/packages/backend/server/src/plugins/copilot/providers/gemini/generative.ts +++ b/packages/backend/server/src/plugins/copilot/providers/gemini/generative.ts @@ -1,9 +1,7 @@ -import { - createGoogleGenerativeAI, - type GoogleGenerativeAIProvider, -} from '@ai-sdk/google'; import z from 'zod'; +import type { NativeLlmBackendConfig } from '../../../../native'; +import { GEMINI_ATTACHMENT_CAPABILITY } from '../attachments'; import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types'; import { GeminiProvider } from './gemini'; @@ -20,25 +18,6 @@ export class GeminiGenerativeProvider extends GeminiProvider { + return { + base_url: ( + this.config.baseURL || + 'https://generativelanguage.googleapis.com/v1beta' + ).replace(/\/$/, ''), + auth_token: this.config.apiKey, + request_layer: 'gemini_api', + }; + } } diff --git a/packages/backend/server/src/plugins/copilot/providers/gemini/vertex.ts b/packages/backend/server/src/plugins/copilot/providers/gemini/vertex.ts index 24e05ea242..e9ef0735a2 100644 --- a/packages/backend/server/src/plugins/copilot/providers/gemini/vertex.ts +++ 
b/packages/backend/server/src/plugins/copilot/providers/gemini/vertex.ts @@ -1,14 +1,14 @@ -import { - createVertex, - type GoogleVertexProvider, - type GoogleVertexProviderSettings, -} from '@ai-sdk/google-vertex'; - +import type { NativeLlmBackendConfig } from '../../../../native'; +import { GEMINI_ATTACHMENT_CAPABILITY } from '../attachments'; import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types'; -import { getGoogleAuth, VertexModelListSchema } from '../utils'; +import { + getGoogleAuth, + VertexModelListSchema, + type VertexProviderConfig, +} from '../utils'; import { GeminiProvider } from './gemini'; -export type GeminiVertexConfig = GoogleVertexProviderSettings; +export type GeminiVertexConfig = VertexProviderConfig; export class GeminiVertexProvider extends GeminiProvider { override readonly type = CopilotProviderType.GeminiVertex; @@ -23,12 +23,15 @@ export class GeminiVertexProvider extends GeminiProvider { ModelInputType.Text, ModelInputType.Image, ModelInputType.Audio, + ModelInputType.File, ], output: [ ModelOutputType.Text, ModelOutputType.Object, ModelOutputType.Structured, ], + attachments: GEMINI_ATTACHMENT_CAPABILITY, + structuredAttachments: GEMINI_ATTACHMENT_CAPABILITY, }, ], }, @@ -41,12 +44,36 @@ export class GeminiVertexProvider extends GeminiProvider { ModelInputType.Text, ModelInputType.Image, ModelInputType.Audio, + ModelInputType.File, ], output: [ ModelOutputType.Text, ModelOutputType.Object, ModelOutputType.Structured, ], + attachments: GEMINI_ATTACHMENT_CAPABILITY, + structuredAttachments: GEMINI_ATTACHMENT_CAPABILITY, + }, + ], + }, + { + name: 'Gemini 3.1 Pro Preview', + id: 'gemini-3.1-pro-preview', + capabilities: [ + { + input: [ + ModelInputType.Text, + ModelInputType.Image, + ModelInputType.Audio, + ModelInputType.File, + ], + output: [ + ModelOutputType.Text, + ModelOutputType.Object, + ModelOutputType.Structured, + ], + attachments: GEMINI_ATTACHMENT_CAPABILITY, + structuredAttachments: 
GEMINI_ATTACHMENT_CAPABILITY, }, ], }, @@ -62,21 +89,13 @@ export class GeminiVertexProvider extends GeminiProvider { ], }, ]; - - protected instance!: GoogleVertexProvider; - override configured(): boolean { return !!this.config.location && !!this.config.googleAuthOptions; } - protected override setup() { - super.setup(); - this.instance = createVertex(this.config); - } - override async refreshOnlineModels() { try { - const { baseUrl, headers } = await getGoogleAuth(this.config, 'google'); + const { baseUrl, headers } = await this.resolveVertexAuth(); if (baseUrl && !this.onlineModelList.length) { const { publisherModels } = await fetch(`${baseUrl}/models`, { headers: headers(), @@ -91,4 +110,19 @@ export class GeminiVertexProvider extends GeminiProvider { this.logger.error('Failed to fetch available models', e); } } + + protected async resolveVertexAuth() { + return await getGoogleAuth(this.config, 'google'); + } + + protected override async createNativeConfig(): Promise { + const auth = await this.resolveVertexAuth(); + const { Authorization: authHeader } = auth.headers(); + + return { + base_url: auth.baseUrl || '', + auth_token: authHeader.replace(/^Bearer\s+/i, ''), + request_layer: 'gemini_vertex', + }; + } } diff --git a/packages/backend/server/src/plugins/copilot/providers/loop.ts b/packages/backend/server/src/plugins/copilot/providers/loop.ts new file mode 100644 index 0000000000..24e26522aa --- /dev/null +++ b/packages/backend/server/src/plugins/copilot/providers/loop.ts @@ -0,0 +1,479 @@ +import { z } from 'zod'; + +import type { + NativeLlmRequest, + NativeLlmStreamEvent, + NativeLlmToolDefinition, +} from '../../../native'; +import type { + CopilotTool, + CopilotToolExecuteOptions, + CopilotToolSet, +} from '../tools'; + +export type NativeDispatchFn = ( + request: NativeLlmRequest, + signal?: AbortSignal +) => AsyncIterableIterator; + +export type NativeToolCall = { + id: string; + name: string; + args: Record; + rawArgumentsText?: string; + 
argumentParseError?: string; + thought?: string; +}; + +type ToolCallState = { + name?: string; + argumentsText: string; +}; + +type ToolExecutionResult = { + callId: string; + name: string; + args: Record; + rawArgumentsText?: string; + argumentParseError?: string; + output: unknown; + isError?: boolean; +}; + +type ParsedToolArguments = { + args: Record; + rawArgumentsText?: string; + argumentParseError?: string; +}; + +export class ToolCallAccumulator { + readonly #states = new Map(); + + feedDelta(event: Extract) { + const state = this.#states.get(event.call_id) ?? { + argumentsText: '', + }; + if (event.name) { + state.name = event.name; + } + if (event.arguments_delta) { + state.argumentsText += event.arguments_delta; + } + this.#states.set(event.call_id, state); + } + + complete(event: Extract) { + const state = this.#states.get(event.call_id); + this.#states.delete(event.call_id); + const parsed = + event.arguments_text !== undefined || event.arguments_error !== undefined + ? { + args: event.arguments ?? {}, + rawArgumentsText: event.arguments_text ?? state?.argumentsText, + argumentParseError: event.arguments_error, + } + : event.arguments + ? this.parseArgs(event.arguments, state?.argumentsText) + : this.parseJson(state?.argumentsText ?? 
'{}'); + return { + id: event.call_id, + name: event.name || state?.name || '', + ...parsed, + thought: event.thought, + } satisfies NativeToolCall; + } + + drainPending() { + const pending: NativeToolCall[] = []; + for (const [callId, state] of this.#states.entries()) { + if (!state.name) { + continue; + } + pending.push({ + id: callId, + name: state.name, + ...this.parseJson(state.argumentsText), + }); + } + this.#states.clear(); + return pending; + } + + private parseJson(jsonText: string): ParsedToolArguments { + if (!jsonText.trim()) { + return { args: {} }; + } + try { + return this.parseArgs(JSON.parse(jsonText), jsonText); + } catch (error) { + return { + args: {}, + rawArgumentsText: jsonText, + argumentParseError: + error instanceof Error + ? error.message + : 'Invalid tool arguments JSON', + }; + } + } + + private parseArgs( + value: unknown, + rawArgumentsText?: string + ): ParsedToolArguments { + if (value && typeof value === 'object' && !Array.isArray(value)) { + return { + args: value as Record, + rawArgumentsText, + }; + } + return { + args: {}, + rawArgumentsText, + argumentParseError: 'Tool arguments must be a JSON object', + }; + } +} + +export class ToolSchemaExtractor { + static extract(toolSet: CopilotToolSet): NativeLlmToolDefinition[] { + return Object.entries(toolSet).map(([name, tool]) => { + return { + name, + description: tool.description, + parameters: this.toJsonSchema(tool.inputSchema ?? 
z.object({})), + }; + }); + } + + static toJsonSchema(schema: unknown): Record { + if (!(schema instanceof z.ZodType)) { + if (schema && typeof schema === 'object' && !Array.isArray(schema)) { + return schema as Record; + } + return { type: 'object', properties: {} }; + } + + if (schema instanceof z.ZodObject) { + const shape = schema.shape; + const properties: Record = {}; + const required: string[] = []; + + for (const [key, child] of Object.entries( + shape as Record + )) { + properties[key] = this.toJsonSchema(child); + if (!this.isOptional(child)) { + required.push(key); + } + } + + return { + type: 'object', + properties, + additionalProperties: false, + ...(required.length ? { required } : {}), + }; + } + + if (schema instanceof z.ZodString) { + return { type: 'string' }; + } + if (schema instanceof z.ZodNumber) { + return { type: 'number' }; + } + if (schema instanceof z.ZodBoolean) { + return { type: 'boolean' }; + } + if (schema instanceof z.ZodArray) { + return { type: 'array', items: this.toJsonSchema(schema.element) }; + } + if (schema instanceof z.ZodEnum) { + return { type: 'string', enum: schema.options }; + } + if (schema instanceof z.ZodLiteral) { + const literal = schema.value; + if (literal === null) { + return { const: null, type: 'null' }; + } + if (typeof literal === 'string') { + return { const: literal, type: 'string' }; + } + if (typeof literal === 'number') { + return { const: literal, type: 'number' }; + } + if (typeof literal === 'boolean') { + return { const: literal, type: 'boolean' }; + } + return { const: literal }; + } + if (schema instanceof z.ZodUnion) { + return { + anyOf: schema.options.map((option: z.ZodTypeAny) => + this.toJsonSchema(option) + ), + }; + } + if (schema instanceof z.ZodRecord) { + return { + type: 'object', + additionalProperties: this.toJsonSchema(schema.valueSchema), + }; + } + + if (schema instanceof z.ZodNullable) { + const inner = (schema._def as { innerType?: z.ZodTypeAny }).innerType; + return { anyOf: 
[this.toJsonSchema(inner), { type: 'null' }] }; + } + + if (schema instanceof z.ZodOptional || schema instanceof z.ZodDefault) { + return this.toJsonSchema( + (schema._def as { innerType?: z.ZodTypeAny }).innerType + ); + } + + if (schema instanceof z.ZodEffects) { + return this.toJsonSchema( + (schema._def as { schema?: z.ZodTypeAny }).schema + ); + } + + return { type: 'object', properties: {} }; + } + + private static isOptional(schema: z.ZodTypeAny): boolean { + if (schema instanceof z.ZodOptional || schema instanceof z.ZodDefault) { + return true; + } + if (schema instanceof z.ZodNullable) { + return this.isOptional( + (schema._def as { innerType: z.ZodTypeAny }).innerType + ); + } + if (schema instanceof z.ZodEffects) { + return this.isOptional((schema._def as { schema: z.ZodTypeAny }).schema); + } + return false; + } +} + +export class ToolCallLoop { + constructor( + private readonly dispatch: NativeDispatchFn, + private readonly tools: CopilotToolSet, + private readonly maxSteps = 20 + ) {} + + private normalizeToolExecuteOptions( + signalOrOptions?: AbortSignal | CopilotToolExecuteOptions, + maybeMessages?: CopilotToolExecuteOptions['messages'] + ): CopilotToolExecuteOptions { + if ( + signalOrOptions && + typeof signalOrOptions === 'object' && + 'aborted' in signalOrOptions + ) { + return { + signal: signalOrOptions, + messages: maybeMessages, + }; + } + + if (!signalOrOptions) { + return maybeMessages ? { messages: maybeMessages } : {}; + } + + return { + ...signalOrOptions, + signal: signalOrOptions.signal, + messages: signalOrOptions.messages ?? 
maybeMessages, + }; + } + + async *run( + request: NativeLlmRequest, + signalOrOptions?: AbortSignal | CopilotToolExecuteOptions, + maybeMessages?: CopilotToolExecuteOptions['messages'] + ): AsyncIterableIterator { + const toolExecuteOptions = this.normalizeToolExecuteOptions( + signalOrOptions, + maybeMessages + ); + const messages = request.messages.map(message => ({ + ...message, + content: [...message.content], + })); + + for (let step = 0; step < this.maxSteps; step++) { + const toolCalls: NativeToolCall[] = []; + const accumulator = new ToolCallAccumulator(); + let finalDone: Extract | null = + null; + + for await (const event of this.dispatch( + { + ...request, + stream: true, + messages, + }, + toolExecuteOptions.signal + )) { + switch (event.type) { + case 'tool_call_delta': { + accumulator.feedDelta(event); + break; + } + case 'tool_call': { + toolCalls.push(accumulator.complete(event)); + yield event; + break; + } + case 'done': { + finalDone = event; + break; + } + case 'error': { + throw new Error(event.message); + } + default: { + yield event; + break; + } + } + } + + toolCalls.push(...accumulator.drainPending()); + if (toolCalls.length === 0) { + if (finalDone) { + yield finalDone; + } + break; + } + + if (step === this.maxSteps - 1) { + throw new Error('ToolCallLoop max steps reached'); + } + + const toolResults = await this.executeTools( + toolCalls, + toolExecuteOptions + ); + + messages.push({ + role: 'assistant', + content: toolCalls.map(call => ({ + type: 'tool_call', + call_id: call.id, + name: call.name, + arguments: call.args, + arguments_text: call.rawArgumentsText, + arguments_error: call.argumentParseError, + thought: call.thought, + })), + }); + + for (const result of toolResults) { + messages.push({ + role: 'tool', + content: [ + { + type: 'tool_result', + call_id: result.callId, + name: result.name, + arguments: result.args, + arguments_text: result.rawArgumentsText, + arguments_error: result.argumentParseError, + output: 
result.output, + is_error: result.isError, + }, + ], + }); + yield { + type: 'tool_result', + call_id: result.callId, + name: result.name, + arguments: result.args, + arguments_text: result.rawArgumentsText, + arguments_error: result.argumentParseError, + output: result.output, + is_error: result.isError, + }; + } + } + } + + private async executeTools( + calls: NativeToolCall[], + options: CopilotToolExecuteOptions + ) { + return await Promise.all( + calls.map(call => this.executeTool(call, options)) + ); + } + + private async executeTool( + call: NativeToolCall, + options: CopilotToolExecuteOptions + ): Promise { + const tool = this.tools[call.name] as CopilotTool | undefined; + + if (!tool?.execute) { + return { + callId: call.id, + name: call.name, + args: call.args, + rawArgumentsText: call.rawArgumentsText, + argumentParseError: call.argumentParseError, + isError: true, + output: { + message: `Tool not found: ${call.name}`, + }, + }; + } + + if (call.argumentParseError) { + return { + callId: call.id, + name: call.name, + args: call.args, + rawArgumentsText: call.rawArgumentsText, + argumentParseError: call.argumentParseError, + isError: true, + output: { + message: 'Invalid tool arguments JSON', + rawArguments: call.rawArgumentsText, + error: call.argumentParseError, + }, + }; + } + + try { + const output = await tool.execute(call.args, options); + return { + callId: call.id, + name: call.name, + args: call.args, + rawArgumentsText: call.rawArgumentsText, + argumentParseError: call.argumentParseError, + output: output ?? 
null, + }; + } catch (error) { + console.error('Tool execution failed', { + callId: call.id, + toolName: call.name, + error, + }); + return { + callId: call.id, + name: call.name, + args: call.args, + rawArgumentsText: call.rawArgumentsText, + argumentParseError: call.argumentParseError, + isError: true, + output: { + message: 'Tool execution failed', + }, + }; + } + } +} diff --git a/packages/backend/server/src/plugins/copilot/providers/morph.ts b/packages/backend/server/src/plugins/copilot/providers/morph.ts index 36f96c9f26..a4298315a2 100644 --- a/packages/backend/server/src/plugins/copilot/providers/morph.ts +++ b/packages/backend/server/src/plugins/copilot/providers/morph.ts @@ -1,14 +1,16 @@ -import { - createOpenAICompatible, - OpenAICompatibleProvider as VercelOpenAICompatibleProvider, -} from '@ai-sdk/openai-compatible'; -import { AISDKError, generateText, streamText } from 'ai'; - import { CopilotProviderSideError, metrics, UserFriendlyError, } from '../../../base'; +import { + llmDispatchStream, + type NativeLlmBackendConfig, + type NativeLlmRequest, +} from '../../../native'; +import type { NodeTextMiddleware } from '../config'; +import type { CopilotToolSet } from '../tools'; +import { buildNativeRequest, NativeProviderAdapter } from './native'; import { CopilotProvider } from './provider'; import type { CopilotChatOptions, @@ -16,7 +18,6 @@ import type { PromptMessage, } from './types'; import { CopilotProviderType, ModelInputType, ModelOutputType } from './types'; -import { chatToGPTMessage, TextStreamParser } from './utils'; export const DEFAULT_DIMENSIONS = 256; @@ -57,37 +58,48 @@ export class MorphProvider extends CopilotProvider { }, ]; - #instance!: VercelOpenAICompatibleProvider; - override configured(): boolean { return !!this.config.apiKey; } protected override setup() { super.setup(); - this.#instance = createOpenAICompatible({ - name: this.type, - apiKey: this.config.apiKey, - baseURL: 'https://api.morphllm.com/v1', - }); } private 
handleError(e: any) { if (e instanceof UserFriendlyError) { return e; - } else if (e instanceof AISDKError) { - return new CopilotProviderSideError({ - provider: this.type, - kind: e.name || 'unknown', - message: e.message, - }); - } else { - return new CopilotProviderSideError({ - provider: this.type, - kind: 'unexpected_response', - message: e?.message || 'Unexpected morph response', - }); } + return new CopilotProviderSideError({ + provider: this.type, + kind: 'unexpected_response', + message: e?.message || 'Unexpected morph response', + }); + } + + private createNativeConfig(): NativeLlmBackendConfig { + return { + base_url: 'https://api.morphllm.com', + auth_token: this.config.apiKey ?? '', + }; + } + + private createNativeAdapter( + tools: CopilotToolSet, + nodeTextMiddleware?: NodeTextMiddleware[] + ) { + return new NativeProviderAdapter( + (request: NativeLlmRequest, signal?: AbortSignal) => + llmDispatchStream( + 'openai_chat', + this.createNativeConfig(), + request, + signal + ), + tools, + this.MAX_STEPS, + { nodeTextMiddleware } + ); } async text( @@ -95,30 +107,32 @@ export class MorphProvider extends CopilotProvider { messages: PromptMessage[], options: CopilotChatOptions = {} ): Promise { - const fullCond = { - ...cond, - outputType: ModelOutputType.Text, - }; - await this.checkParams({ messages, cond: fullCond, options }); - const model = this.selectModel(fullCond); + const fullCond = { ...cond, outputType: ModelOutputType.Text }; + const model = this.selectModel( + await this.checkParams({ + messages, + cond: fullCond, + options, + }) + ); try { - metrics.ai.counter('chat_text_calls').add(1, { model: model.id }); - - const [system, msgs] = await chatToGPTMessage(messages); - - const modelInstance = this.#instance(model.id); - - const { text } = await generateText({ - model: modelInstance, - system, - messages: msgs, - abortSignal: options.signal, + metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id)); + const tools = await 
this.getTools(options, model.id); + const middleware = this.getActiveProviderMiddleware(); + const { request } = await buildNativeRequest({ + model: model.id, + messages, + options, + tools, + middleware, }); - - return text.trim(); + const adapter = this.createNativeAdapter(tools, middleware.node?.text); + return await adapter.text(request, options.signal, messages); } catch (e: any) { - metrics.ai.counter('chat_text_errors').add(1, { model: model.id }); + metrics.ai + .counter('chat_text_errors') + .add(1, this.metricLabels(model.id)); throw this.handleError(e); } } @@ -128,46 +142,40 @@ export class MorphProvider extends CopilotProvider { messages: PromptMessage[], options: CopilotChatOptions = {} ): AsyncIterable { - const fullCond = { - ...cond, - outputType: ModelOutputType.Text, - }; - await this.checkParams({ messages, cond: fullCond, options }); - const model = this.selectModel(fullCond); + const fullCond = { ...cond, outputType: ModelOutputType.Text }; + const model = this.selectModel( + await this.checkParams({ + messages, + cond: fullCond, + options, + }) + ); try { - metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id }); - const [system, msgs] = await chatToGPTMessage(messages); - - const modelInstance = this.#instance(model.id); - - const { fullStream } = streamText({ - model: modelInstance, - system, - messages: msgs, - abortSignal: options.signal, + metrics.ai + .counter('chat_text_stream_calls') + .add(1, this.metricLabels(model.id)); + const tools = await this.getTools(options, model.id); + const middleware = this.getActiveProviderMiddleware(); + const { request } = await buildNativeRequest({ + model: model.id, + messages, + options, + tools, + middleware, }); - - const textParser = new TextStreamParser(); - for await (const chunk of fullStream) { - switch (chunk.type) { - case 'text-delta': { - let result = textParser.parse(chunk); - yield result; - break; - } - default: { - yield textParser.parse(chunk); - break; - } - } - if 
(options.signal?.aborted) { - await fullStream.cancel(); - break; - } + const adapter = this.createNativeAdapter(tools, middleware.node?.text); + for await (const chunk of adapter.streamText( + request, + options.signal, + messages + )) { + yield chunk; } } catch (e: any) { - metrics.ai.counter('chat_text_stream_errors').add(1, { model: model.id }); + metrics.ai + .counter('chat_text_stream_errors') + .add(1, this.metricLabels(model.id)); throw this.handleError(e); } } diff --git a/packages/backend/server/src/plugins/copilot/providers/native.ts b/packages/backend/server/src/plugins/copilot/providers/native.ts new file mode 100644 index 0000000000..355e68d735 --- /dev/null +++ b/packages/backend/server/src/plugins/copilot/providers/native.ts @@ -0,0 +1,675 @@ +import { ZodType } from 'zod'; + +import { CopilotPromptInvalid } from '../../../base'; +import type { + NativeLlmCoreContent, + NativeLlmCoreMessage, + NativeLlmEmbeddingRequest, + NativeLlmRequest, + NativeLlmStreamEvent, + NativeLlmStructuredRequest, + NativeLlmStructuredResponse, +} from '../../../native'; +import type { NodeTextMiddleware, ProviderMiddlewareConfig } from '../config'; +import type { CopilotToolSet } from '../tools'; +import { + canonicalizePromptAttachment, + type CanonicalPromptAttachment, +} from './attachments'; +import { NativeDispatchFn, ToolCallLoop, ToolSchemaExtractor } from './loop'; +import type { + CopilotChatOptions, + CopilotStructuredOptions, + ModelAttachmentCapability, + PromptMessage, + StreamObject, +} from './types'; +import { CitationFootnoteFormatter, TextStreamParser } from './utils'; + +type BuildNativeRequestOptions = { + model: string; + messages: PromptMessage[]; + options?: CopilotChatOptions | CopilotStructuredOptions; + tools?: CopilotToolSet; + withAttachment?: boolean; + attachmentCapability?: ModelAttachmentCapability; + include?: string[]; + reasoning?: Record; + responseSchema?: unknown; + middleware?: ProviderMiddlewareConfig; +}; + +type 
BuildNativeRequestResult = { + request: NativeLlmRequest; + schema?: ZodType; +}; + +type BuildNativeStructuredRequestResult = { + request: NativeLlmStructuredRequest; + schema: ZodType; +}; + +type ToolCallMeta = { + name: string; + args: Record; +}; + +type NormalizedToolResultEvent = Extract< + NativeLlmStreamEvent, + { type: 'tool_result' } +> & { + name: string; + arguments: Record; +}; + +type AttachmentFootnote = { + blobId: string; + fileName: string; + fileType: string; +}; + +type NativeProviderAdapterOptions = { + nodeTextMiddleware?: NodeTextMiddleware[]; +}; + +function roleToCore(role: PromptMessage['role']) { + switch (role) { + case 'assistant': + return 'assistant'; + case 'system': + return 'system'; + default: + return 'user'; + } +} + +function ensureAttachmentSupported( + attachment: CanonicalPromptAttachment, + attachmentCapability?: ModelAttachmentCapability +) { + if (!attachmentCapability) return; + + if (!attachmentCapability.kinds.includes(attachment.kind)) { + throw new CopilotPromptInvalid( + `Native path does not support ${attachment.kind} attachments${ + attachment.mediaType ? ` (${attachment.mediaType})` : '' + }` + ); + } + + if ( + attachmentCapability.sourceKinds?.length && + !attachmentCapability.sourceKinds.includes(attachment.sourceKind) + ) { + throw new CopilotPromptInvalid( + `Native path does not support ${attachment.sourceKind} attachment sources` + ); + } + + if (attachment.isRemote && attachmentCapability.allowRemoteUrls === false) { + throw new CopilotPromptInvalid( + 'Native path does not support remote attachment urls' + ); + } +} + +function resolveResponseSchema( + systemMessage: PromptMessage | undefined, + responseSchema?: unknown +): ZodType | undefined { + if (responseSchema instanceof ZodType) { + return responseSchema; + } + + if (systemMessage?.responseFormat?.schema instanceof ZodType) { + return systemMessage.responseFormat.schema; + } + + return systemMessage?.params?.schema instanceof ZodType + ? 
systemMessage.params.schema + : undefined; +} + +function resolveResponseStrict( + systemMessage: PromptMessage | undefined, + options?: CopilotStructuredOptions +) { + return options?.strict ?? systemMessage?.responseFormat?.strict ?? true; +} + +export class StructuredResponseParseError extends Error {} + +function normalizeStructuredText(text: string) { + const trimmed = text.replaceAll(/^ny\n/g, ' ').trim(); + if (trimmed.startsWith('```') || trimmed.endsWith('```')) { + return trimmed + .replace(/```[\w\s-]*\n/g, '') + .replace(/\n```/g, '') + .trim(); + } + return trimmed; +} + +export function parseNativeStructuredOutput( + response: Pick & { + output_json?: unknown; + } +) { + if (response.output_json !== undefined) { + return response.output_json; + } + + const normalized = normalizeStructuredText(response.output_text); + const candidates = [ + () => normalized, + () => { + const objectStart = normalized.indexOf('{'); + const objectEnd = normalized.lastIndexOf('}'); + return objectStart !== -1 && objectEnd > objectStart + ? normalized.slice(objectStart, objectEnd + 1) + : null; + }, + () => { + const arrayStart = normalized.indexOf('['); + const arrayEnd = normalized.lastIndexOf(']'); + return arrayStart !== -1 && arrayEnd > arrayStart + ? 
normalized.slice(arrayStart, arrayEnd + 1) + : null; + }, + ]; + + for (const candidate of candidates) { + try { + const candidateText = candidate(); + if (typeof candidateText === 'string') { + return JSON.parse(candidateText); + } + } catch { + continue; + } + } + + throw new StructuredResponseParseError( + `Unexpected structured response: ${normalized.slice(0, 200)}` + ); +} + +async function toCoreContents( + message: PromptMessage, + withAttachment: boolean, + attachmentCapability?: ModelAttachmentCapability +): Promise { + const contents: NativeLlmCoreContent[] = []; + + if (typeof message.content === 'string' && message.content.length) { + contents.push({ type: 'text', text: message.content }); + } + + if (!withAttachment || !Array.isArray(message.attachments)) return contents; + + for (const entry of message.attachments) { + const normalized = await canonicalizePromptAttachment(entry, message); + ensureAttachmentSupported(normalized, attachmentCapability); + contents.push({ + type: normalized.kind, + source: normalized.source, + }); + } + + return contents; +} + +export async function buildNativeRequest({ + model, + messages, + options = {}, + tools = {}, + withAttachment = true, + attachmentCapability, + include, + reasoning, + responseSchema, + middleware, +}: BuildNativeRequestOptions): Promise { + const copiedMessages = messages.map(message => ({ + ...message, + attachments: message.attachments + ? [...message.attachments] + : message.attachments, + })); + + const systemMessage = + copiedMessages[0]?.role === 'system' ? 
copiedMessages.shift() : undefined; + const schema = resolveResponseSchema(systemMessage, responseSchema); + + const coreMessages: NativeLlmCoreMessage[] = []; + if (systemMessage?.content?.length) { + coreMessages.push({ + role: 'system', + content: [{ type: 'text', text: systemMessage.content }], + }); + } + + for (const message of copiedMessages) { + if (message.role === 'system') continue; + const content = await toCoreContents( + message, + withAttachment, + attachmentCapability + ); + coreMessages.push({ role: roleToCore(message.role), content }); + } + + return { + request: { + model, + stream: true, + messages: coreMessages, + max_tokens: options.maxTokens ?? undefined, + temperature: options.temperature ?? undefined, + tools: ToolSchemaExtractor.extract(tools), + tool_choice: Object.keys(tools).length ? 'auto' : undefined, + include, + reasoning, + response_schema: schema + ? ToolSchemaExtractor.toJsonSchema(schema) + : undefined, + middleware: middleware?.rust + ? { request: middleware.rust.request, stream: middleware.rust.stream } + : undefined, + }, + schema, + }; +} + +export async function buildNativeStructuredRequest({ + model, + messages, + options = {}, + withAttachment = true, + attachmentCapability, + reasoning, + responseSchema, + middleware, +}: Omit< + BuildNativeRequestOptions, + 'tools' | 'include' +>): Promise { + const copiedMessages = messages.map(message => ({ + ...message, + attachments: message.attachments + ? [...message.attachments] + : message.attachments, + })); + + const systemMessage = + copiedMessages[0]?.role === 'system' ? 
copiedMessages.shift() : undefined; + const schema = resolveResponseSchema(systemMessage, responseSchema); + const strict = resolveResponseStrict(systemMessage, options); + + if (!schema) { + throw new CopilotPromptInvalid('Schema is required'); + } + + const coreMessages: NativeLlmCoreMessage[] = []; + if (systemMessage?.content?.length) { + coreMessages.push({ + role: 'system', + content: [{ type: 'text', text: systemMessage.content }], + }); + } + + for (const message of copiedMessages) { + if (message.role === 'system') continue; + const content = await toCoreContents( + message, + withAttachment, + attachmentCapability + ); + coreMessages.push({ role: roleToCore(message.role), content }); + } + + return { + request: { + model, + messages: coreMessages, + schema: ToolSchemaExtractor.toJsonSchema(schema), + max_tokens: options.maxTokens ?? undefined, + temperature: options.temperature ?? undefined, + reasoning, + strict, + response_mime_type: 'application/json', + middleware: middleware?.rust + ? { request: middleware.rust.request } + : undefined, + }, + schema, + }; +} + +export function buildNativeEmbeddingRequest({ + model, + inputs, + dimensions, + taskType = 'RETRIEVAL_DOCUMENT', +}: { + model: string; + inputs: string[]; + dimensions?: number; + taskType?: string; +}): NativeLlmEmbeddingRequest { + return { + model, + inputs, + dimensions, + task_type: taskType, + }; +} + +function ensureToolResultMeta( + event: Extract, + toolCalls: Map +): NormalizedToolResultEvent | null { + const name = event.name ?? toolCalls.get(event.call_id)?.name; + const args = event.arguments ?? toolCalls.get(event.call_id)?.args; + + if (!name || !args) return null; + return { ...event, name, arguments: args }; +} + +function pickAttachmentFootnote(value: unknown): AttachmentFootnote | null { + if (!value || typeof value !== 'object') { + return null; + } + + const record = value as Record; + const blobId = + typeof record.blobId === 'string' + ? 
record.blobId + : typeof record.blob_id === 'string' + ? record.blob_id + : undefined; + const fileName = + typeof record.fileName === 'string' + ? record.fileName + : typeof record.name === 'string' + ? record.name + : undefined; + const fileType = + typeof record.fileType === 'string' + ? record.fileType + : typeof record.mimeType === 'string' + ? record.mimeType + : 'application/octet-stream'; + + if (!blobId || !fileName) { + return null; + } + + return { blobId, fileName, fileType }; +} + +function collectAttachmentFootnotes( + event: NormalizedToolResultEvent +): AttachmentFootnote[] { + if (event.name === 'blob_read') { + const item = pickAttachmentFootnote(event.output); + return item ? [item] : []; + } + + if (event.name === 'doc_semantic_search' && Array.isArray(event.output)) { + return event.output + .map(item => pickAttachmentFootnote(item)) + .filter((item): item is AttachmentFootnote => item !== null); + } + + return []; +} + +function formatAttachmentFootnotes(attachments: AttachmentFootnote[]) { + const references = attachments.map((_, index) => `[^${index + 1}]`).join(''); + const definitions = attachments + .map((attachment, index) => { + return `[^${index + 1}]: ${JSON.stringify({ + type: 'attachment', + blobId: attachment.blobId, + fileName: attachment.fileName, + fileType: attachment.fileType, + })}`; + }) + .join('\n'); + + return `\n\n${references}\n\n${definitions}`; +} + +export class NativeProviderAdapter { + readonly #loop: ToolCallLoop; + readonly #enableCallout: boolean; + readonly #enableCitationFootnote: boolean; + + constructor( + dispatch: NativeDispatchFn, + tools: CopilotToolSet, + maxSteps = 20, + options: NativeProviderAdapterOptions = {} + ) { + this.#loop = new ToolCallLoop(dispatch, tools, maxSteps); + const enabledNodeTextMiddlewares = new Set( + options.nodeTextMiddleware ?? 
['citation_footnote', 'callout'] + ); + this.#enableCallout = + enabledNodeTextMiddlewares.has('callout') || + enabledNodeTextMiddlewares.has('thinking_format'); + this.#enableCitationFootnote = + enabledNodeTextMiddlewares.has('citation_footnote'); + } + + async text( + request: NativeLlmRequest, + signal?: AbortSignal, + messages?: PromptMessage[] + ) { + let output = ''; + for await (const chunk of this.streamText(request, signal, messages)) { + output += chunk; + } + return output.trim(); + } + + async *streamText( + request: NativeLlmRequest, + signal?: AbortSignal, + messages?: PromptMessage[] + ): AsyncIterableIterator { + const textParser = this.#enableCallout ? new TextStreamParser() : null; + const citationFormatter = this.#enableCitationFootnote + ? new CitationFootnoteFormatter() + : null; + const toolCalls = new Map(); + let streamPartId = 0; + + for await (const event of this.#loop.run(request, signal, messages)) { + switch (event.type) { + case 'text_delta': { + if (textParser) { + yield textParser.parse({ + type: 'text-delta', + id: String(streamPartId++), + text: event.text, + }); + } else { + yield event.text; + } + break; + } + case 'reasoning_delta': { + if (textParser) { + yield textParser.parse({ + type: 'reasoning-delta', + id: String(streamPartId++), + text: event.text, + }); + } else { + yield event.text; + } + break; + } + case 'tool_call': { + const toolCall = { + name: event.name, + args: event.arguments, + }; + toolCalls.set(event.call_id, toolCall); + if (textParser) { + yield textParser.parse({ + type: 'tool-call', + toolCallId: event.call_id, + toolName: event.name as never, + input: event.arguments, + }); + } + break; + } + case 'tool_result': { + const normalized = ensureToolResultMeta(event, toolCalls); + if (!normalized || !textParser) { + break; + } + yield textParser.parse({ + type: 'tool-result', + toolCallId: normalized.call_id, + toolName: normalized.name as never, + input: normalized.arguments, + output: normalized.output, 
+ }); + break; + } + case 'citation': { + if (citationFormatter) { + citationFormatter.consume({ + type: 'citation', + index: event.index, + url: event.url, + }); + } + break; + } + case 'done': { + const footnotes = textParser?.end() ?? ''; + const citations = citationFormatter?.end() ?? ''; + const tails = [citations, footnotes].filter(Boolean).join('\n'); + if (tails) { + yield `\n${tails}`; + } + break; + } + case 'error': { + throw new Error(event.message); + } + default: + break; + } + } + } + + async *streamObject( + request: NativeLlmRequest, + signal?: AbortSignal, + messages?: PromptMessage[] + ): AsyncIterableIterator { + const toolCalls = new Map(); + const citationFormatter = this.#enableCitationFootnote + ? new CitationFootnoteFormatter() + : null; + const fallbackAttachmentFootnotes = new Map(); + let hasFootnoteReference = false; + + for await (const event of this.#loop.run(request, signal, messages)) { + switch (event.type) { + case 'text_delta': { + if (event.text.includes('[^')) { + hasFootnoteReference = true; + } + yield { + type: 'text-delta', + textDelta: event.text, + }; + break; + } + case 'reasoning_delta': { + yield { + type: 'reasoning', + textDelta: event.text, + }; + break; + } + case 'tool_call': { + const toolCall = { + name: event.name, + args: event.arguments, + }; + toolCalls.set(event.call_id, toolCall); + yield { + type: 'tool-call', + toolCallId: event.call_id, + toolName: event.name, + args: event.arguments, + }; + break; + } + case 'tool_result': { + const normalized = ensureToolResultMeta(event, toolCalls); + if (!normalized) { + break; + } + const attachments = collectAttachmentFootnotes(normalized); + attachments.forEach(attachment => { + fallbackAttachmentFootnotes.set(attachment.blobId, attachment); + }); + yield { + type: 'tool-result', + toolCallId: normalized.call_id, + toolName: normalized.name, + args: normalized.arguments, + result: normalized.output, + }; + break; + } + case 'citation': { + if (citationFormatter) 
{ + citationFormatter.consume({ + type: 'citation', + index: event.index, + url: event.url, + }); + } + break; + } + case 'done': { + const citations = citationFormatter?.end() ?? ''; + if (citations) { + hasFootnoteReference = true; + yield { + type: 'text-delta', + textDelta: `\n${citations}`, + }; + } + if (!hasFootnoteReference && fallbackAttachmentFootnotes.size > 0) { + yield { + type: 'text-delta', + textDelta: formatAttachmentFootnotes( + Array.from(fallbackAttachmentFootnotes.values()) + ), + }; + } + break; + } + case 'error': { + throw new Error(event.message); + } + default: + break; + } + } + } +} diff --git a/packages/backend/server/src/plugins/copilot/providers/openai.ts b/packages/backend/server/src/plugins/copilot/providers/openai.ts index 1b69e9037f..318d0845e4 100644 --- a/packages/backend/server/src/plugins/copilot/providers/openai.ts +++ b/packages/backend/server/src/plugins/copilot/providers/openai.ts @@ -1,56 +1,78 @@ -import { - createOpenAI, - openai, - type OpenAIProvider as VercelOpenAIProvider, - OpenAIResponsesProviderOptions, -} from '@ai-sdk/openai'; -import { - createOpenAICompatible, - type OpenAICompatibleProvider as VercelOpenAICompatibleProvider, -} from '@ai-sdk/openai-compatible'; -import { - AISDKError, - embedMany, - experimental_generateImage as generateImage, - generateObject, - generateText, - stepCountIs, - streamText, - Tool, -} from 'ai'; import { z } from 'zod'; import { CopilotPromptInvalid, - CopilotProviderNotSupported, CopilotProviderSideError, - fetchBuffer, metrics, OneMB, + readResponseBufferWithLimit, + safeFetch, UserFriendlyError, } from '../../../base'; +import { + llmDispatchStream, + llmEmbeddingDispatch, + llmRerankDispatch, + llmStructuredDispatch, + type NativeLlmBackendConfig, + type NativeLlmEmbeddingRequest, + type NativeLlmRequest, + type NativeLlmRerankRequest, + type NativeLlmRerankResponse, + type NativeLlmStructuredRequest, +} from '../../../native'; +import type { NodeTextMiddleware } from 
'../config'; +import type { CopilotTool, CopilotToolSet } from '../tools'; +import { IMAGE_ATTACHMENT_CAPABILITY } from './attachments'; +import { + buildNativeEmbeddingRequest, + buildNativeRequest, + buildNativeStructuredRequest, + NativeProviderAdapter, + parseNativeStructuredOutput, +} from './native'; import { CopilotProvider } from './provider'; import type { CopilotChatOptions, CopilotChatTools, CopilotEmbeddingOptions, CopilotImageOptions, - CopilotProviderModel, + CopilotRerankRequest, CopilotStructuredOptions, + ModelCapability, ModelConditions, PromptMessage, StreamObject, } from './types'; import { CopilotProviderType, ModelInputType, ModelOutputType } from './types'; -import { - chatToGPTMessage, - CitationParser, - StreamObjectParser, - TextStreamParser, -} from './utils'; +import { promptAttachmentToUrl } from './utils'; export const DEFAULT_DIMENSIONS = 256; +const GPT_5_SAMPLING_UNSUPPORTED_MODELS = /^(gpt-5(?:$|[.-]))/; + +export function normalizeOpenAIOptionsForModel< + T extends { + frequencyPenalty?: number | null; + presencePenalty?: number | null; + temperature?: number | null; + topP?: number | null; + }, +>(options: T, model: string): T { + if (!GPT_5_SAMPLING_UNSUPPORTED_MODELS.test(model)) { + return options; + } + + const normalizedOptions = { ...options }; + + delete normalizedOptions.frequencyPenalty; + delete normalizedOptions.presencePenalty; + delete normalizedOptions.temperature; + delete normalizedOptions.topP; + + return normalizedOptions; +} + export type OpenAIConfig = { apiKey: string; baseURL?: string; @@ -63,7 +85,12 @@ const ModelListSchema = z.object({ const ImageResponseSchema = z.union([ z.object({ - data: z.array(z.object({ b64_json: z.string() })), + data: z.array( + z.object({ + b64_json: z.string().optional(), + url: z.string().optional(), + }) + ), }), z.object({ error: z.object({ @@ -74,18 +101,65 @@ const ImageResponseSchema = z.union([ }), }), ]); -const LogProbsSchema = z.array( - z.object({ - token: 
z.string(), - logprob: z.number(), - top_logprobs: z.array( - z.object({ - token: z.string(), - logprob: z.number(), - }) - ), - }) -); +const TRUSTED_ATTACHMENT_HOST_SUFFIXES = ['cdn.affine.pro']; + +function normalizeImageFormatToMime(format?: string) { + switch (format?.toLowerCase()) { + case 'jpg': + case 'jpeg': + return 'image/jpeg'; + case 'webp': + return 'image/webp'; + case 'png': + return 'image/png'; + case 'gif': + return 'image/gif'; + default: + return 'image/png'; + } +} + +function normalizeImageResponseData( + data: { b64_json?: string; url?: string }[], + mimeType: string = 'image/png' +) { + return data + .map(image => { + if (image.b64_json) { + return `data:${mimeType};base64,${image.b64_json}`; + } + return image.url; + }) + .filter((value): value is string => typeof value === 'string'); +} + +function buildOpenAIRerankRequest( + model: string, + request: CopilotRerankRequest +): NativeLlmRerankRequest { + return { + model, + query: request.query, + candidates: request.candidates.map(candidate => ({ + ...(candidate.id ? { id: candidate.id } : {}), + text: candidate.text, + })), + ...(request.topK ? 
{ top_n: request.topK } : {}), + }; +} + +function createOpenAIMultimodalCapability( + output: ModelCapability['output'], + options: Pick = {} +): ModelCapability { + return { + input: [ModelInputType.Text, ModelInputType.Image], + output, + attachments: IMAGE_ATTACHMENT_CAPABILITY, + structuredAttachments: IMAGE_ATTACHMENT_CAPABILITY, + ...options, + }; +} export class OpenAIProvider extends CopilotProvider { readonly type = CopilotProviderType.OpenAI; @@ -96,10 +170,10 @@ export class OpenAIProvider extends CopilotProvider { name: 'GPT 4o', id: 'gpt-4o', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ModelOutputType.Text, ModelOutputType.Object], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ]), ], }, // FIXME(@darkskygit): deprecated @@ -107,20 +181,20 @@ export class OpenAIProvider extends CopilotProvider { name: 'GPT 4o 2024-08-06', id: 'gpt-4o-2024-08-06', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ModelOutputType.Text, ModelOutputType.Object], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ]), ], }, { name: 'GPT 4o Mini', id: 'gpt-4o-mini', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ModelOutputType.Text, ModelOutputType.Object], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ]), ], }, // FIXME(@darkskygit): deprecated @@ -128,153 +202,158 @@ export class OpenAIProvider extends CopilotProvider { name: 'GPT 4o Mini 2024-07-18', id: 'gpt-4o-mini-2024-07-18', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ModelOutputType.Text, ModelOutputType.Object], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ]), ], }, { name: 'GPT 4.1', id: 'gpt-4.1', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - 
output: [ + createOpenAIMultimodalCapability( + [ ModelOutputType.Text, ModelOutputType.Object, + ModelOutputType.Rerank, ModelOutputType.Structured, ], - defaultForOutputType: true, - }, + { defaultForOutputType: true } + ), ], }, { name: 'GPT 4.1 2025-04-14', id: 'gpt-4.1-2025-04-14', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ - ModelOutputType.Text, - ModelOutputType.Object, - ModelOutputType.Structured, - ], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ModelOutputType.Rerank, + ModelOutputType.Structured, + ]), ], }, { name: 'GPT 4.1 Mini', id: 'gpt-4.1-mini', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ - ModelOutputType.Text, - ModelOutputType.Object, - ModelOutputType.Structured, - ], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ModelOutputType.Rerank, + ModelOutputType.Structured, + ]), ], }, { name: 'GPT 4.1 Nano', id: 'gpt-4.1-nano', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ - ModelOutputType.Text, - ModelOutputType.Object, - ModelOutputType.Structured, - ], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ModelOutputType.Rerank, + ModelOutputType.Structured, + ]), ], }, { name: 'GPT 5', id: 'gpt-5', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ - ModelOutputType.Text, - ModelOutputType.Object, - ModelOutputType.Structured, - ], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ModelOutputType.Structured, + ]), ], }, { name: 'GPT 5 2025-08-07', id: 'gpt-5-2025-08-07', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ - ModelOutputType.Text, - ModelOutputType.Object, - ModelOutputType.Structured, - ], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + 
ModelOutputType.Object, + ModelOutputType.Structured, + ]), ], }, { name: 'GPT 5 Mini', id: 'gpt-5-mini', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ - ModelOutputType.Text, - ModelOutputType.Object, - ModelOutputType.Structured, - ], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ModelOutputType.Structured, + ]), + ], + }, + { + name: 'GPT 5.2', + id: 'gpt-5.2', + capabilities: [ + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ModelOutputType.Rerank, + ModelOutputType.Structured, + ]), + ], + }, + { + name: 'GPT 5.2 2025-12-11', + id: 'gpt-5.2-2025-12-11', + capabilities: [ + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ModelOutputType.Structured, + ]), ], }, { name: 'GPT 5 Nano', id: 'gpt-5-nano', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ - ModelOutputType.Text, - ModelOutputType.Object, - ModelOutputType.Structured, - ], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ModelOutputType.Structured, + ]), ], }, { name: 'GPT O1', id: 'o1', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ModelOutputType.Text, ModelOutputType.Object], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ]), ], }, { name: 'GPT O3', id: 'o3', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ModelOutputType.Text, ModelOutputType.Object], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ]), ], }, { name: 'GPT O4 Mini', id: 'o4-mini', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ModelOutputType.Text, ModelOutputType.Object], - }, + createOpenAIMultimodalCapability([ + ModelOutputType.Text, + ModelOutputType.Object, + ]), ], }, // 
Embedding models @@ -310,62 +389,30 @@ export class OpenAIProvider extends CopilotProvider { { id: 'gpt-image-1', capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ModelOutputType.Image], + createOpenAIMultimodalCapability([ModelOutputType.Image], { defaultForOutputType: true, - }, + }), ], }, ]; - #instance!: VercelOpenAIProvider | VercelOpenAICompatibleProvider; - override configured(): boolean { return !!this.config.apiKey; } protected override setup() { super.setup(); - this.#instance = - this.config.oldApiStyle && this.config.baseURL - ? createOpenAICompatible({ - name: 'openai-compatible-old-style', - apiKey: this.config.apiKey, - baseURL: this.config.baseURL, - }) - : createOpenAI({ - apiKey: this.config.apiKey, - baseURL: this.config.baseURL, - }); } - private handleError( - e: any, - model: string, - options: CopilotImageOptions = {} - ) { + private handleError(e: any) { if (e instanceof UserFriendlyError) { return e; - } else if (e instanceof AISDKError) { - if (e.message.includes('safety') || e.message.includes('risk')) { - metrics.ai - .counter('chat_text_risk_errors') - .add(1, { model, user: options.user || undefined }); - } - - return new CopilotProviderSideError({ - provider: this.type, - kind: e.name || 'unknown', - message: e.message, - }); - } else { - return new CopilotProviderSideError({ - provider: this.type, - kind: 'unexpected_response', - message: e?.message || 'Unexpected openai response', - }); } + return new CopilotProviderSideError({ + provider: this.type, + kind: 'unexpected_response', + message: e?.message || 'Unexpected openai response', + }); } override async refreshOnlineModels() { @@ -389,57 +436,114 @@ export class OpenAIProvider extends CopilotProvider { override getProviderSpecificTools( toolName: CopilotChatTools, - model: string - ): [string, Tool?] 
| undefined { - if ( - toolName === 'webSearch' && - 'responses' in this.#instance && - !this.isReasoningModel(model) - ) { - return ['web_search_preview', openai.tools.webSearch({})]; - } else if (toolName === 'docEdit') { + _model: string + ): [string, CopilotTool?] | undefined { + if (toolName === 'docEdit') { return ['doc_edit', undefined]; } return; } + private createNativeConfig(): NativeLlmBackendConfig { + const baseUrl = this.config.baseURL || 'https://api.openai.com/v1'; + return { + base_url: baseUrl.replace(/\/v1\/?$/, ''), + auth_token: this.config.apiKey, + }; + } + + private getNativeProtocol() { + return this.config.oldApiStyle ? 'openai_chat' : 'openai_responses'; + } + + private createNativeAdapter( + tools: CopilotToolSet, + nodeTextMiddleware?: NodeTextMiddleware[] + ) { + return new NativeProviderAdapter( + (request: NativeLlmRequest, signal?: AbortSignal) => + llmDispatchStream( + this.getNativeProtocol(), + this.createNativeConfig(), + request, + signal + ), + tools, + this.MAX_STEPS, + { nodeTextMiddleware } + ); + } + + protected createNativeStructuredDispatch( + backendConfig: NativeLlmBackendConfig + ) { + return (request: NativeLlmStructuredRequest) => + llmStructuredDispatch(this.getNativeProtocol(), backendConfig, request); + } + + protected createNativeEmbeddingDispatch( + backendConfig: NativeLlmBackendConfig + ) { + return (request: NativeLlmEmbeddingRequest) => + llmEmbeddingDispatch(this.getNativeProtocol(), backendConfig, request); + } + + protected createNativeRerankDispatch(backendConfig: NativeLlmBackendConfig) { + return ( + request: NativeLlmRerankRequest + ): Promise => + llmRerankDispatch('openai_chat', backendConfig, request); + } + + private getReasoning( + options: NonNullable, + model: string + ): Record | undefined { + if (options.reasoning && this.isReasoningModel(model)) { + return { effort: 'medium' }; + } + return undefined; + } + async text( cond: ModelConditions, messages: PromptMessage[], options: 
CopilotChatOptions = {} ): Promise { const fullCond = { ...cond, outputType: ModelOutputType.Text }; - await this.checkParams({ messages, cond: fullCond, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + messages, + cond: fullCond, + options, + }); + const model = this.selectModel(normalizedCond); try { - metrics.ai.counter('chat_text_calls').add(1, { model: model.id }); - - const [system, msgs] = await chatToGPTMessage(messages); - - const modelInstance = - 'responses' in this.#instance - ? this.#instance.responses(model.id) - : this.#instance(model.id); - - const { text } = await generateText({ - model: modelInstance, - system, - messages: msgs, - temperature: options.temperature ?? 0, - maxOutputTokens: options.maxTokens ?? 4096, - providerOptions: { - openai: this.getOpenAIOptions(options, model.id), - }, - tools: await this.getTools(options, model.id), - stopWhen: stepCountIs(this.MAX_STEPS), - abortSignal: options.signal, + metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id)); + const tools = await this.getTools(options, model.id); + const middleware = this.getActiveProviderMiddleware(); + const cap = this.getAttachCapability(model, ModelOutputType.Text); + const normalizedOptions = normalizeOpenAIOptionsForModel( + options, + model.id + ); + const { request } = await buildNativeRequest({ + model: model.id, + messages, + options: normalizedOptions, + tools, + attachmentCapability: cap, + include: options.webSearch ? 
['citations'] : undefined, + reasoning: this.getReasoning(options, model.id), + middleware, }); - - return text.trim(); + const adapter = this.createNativeAdapter(tools, middleware.node?.text); + return await adapter.text(request, options.signal, messages); } catch (e: any) { - metrics.ai.counter('chat_text_errors').add(1, { model: model.id }); - throw this.handleError(e, model.id, options); + metrics.ai + .counter('chat_text_errors') + .add(1, this.metricLabels(model.id)); + throw this.handleError(e); } } @@ -452,42 +556,47 @@ export class OpenAIProvider extends CopilotProvider { ...cond, outputType: ModelOutputType.Text, }; - await this.checkParams({ messages, cond: fullCond, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + messages, + cond: fullCond, + options, + }); + const model = this.selectModel(normalizedCond); try { - metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id }); - const fullStream = await this.getFullStream(model, messages, options); - const citationParser = new CitationParser(); - const textParser = new TextStreamParser(); - for await (const chunk of fullStream) { - switch (chunk.type) { - case 'text-delta': { - let result = textParser.parse(chunk); - result = citationParser.parse(result); - yield result; - break; - } - case 'finish': { - const footnotes = textParser.end(); - const result = - citationParser.end() + (footnotes.length ? 
'\n' + footnotes : ''); - yield result; - break; - } - default: { - yield textParser.parse(chunk); - break; - } - } - if (options.signal?.aborted) { - await fullStream.cancel(); - break; - } + metrics.ai + .counter('chat_text_stream_calls') + .add(1, this.metricLabels(model.id)); + const tools = await this.getTools(options, model.id); + const middleware = this.getActiveProviderMiddleware(); + const cap = this.getAttachCapability(model, ModelOutputType.Text); + const normalizedOptions = normalizeOpenAIOptionsForModel( + options, + model.id + ); + const { request } = await buildNativeRequest({ + model: model.id, + messages, + options: normalizedOptions, + tools, + attachmentCapability: cap, + include: options.webSearch ? ['citations'] : undefined, + reasoning: this.getReasoning(options, model.id), + middleware, + }); + const adapter = this.createNativeAdapter(tools, middleware.node?.text); + for await (const chunk of adapter.streamText( + request, + options.signal, + messages + )) { + yield chunk; } } catch (e: any) { - metrics.ai.counter('chat_text_stream_errors').add(1, { model: model.id }); - throw this.handleError(e, model.id, options); + metrics.ai + .counter('chat_text_stream_errors') + .add(1, this.metricLabels(model.id)); + throw this.handleError(e); } } @@ -497,30 +606,47 @@ export class OpenAIProvider extends CopilotProvider { options: CopilotChatOptions = {} ): AsyncIterable { const fullCond = { ...cond, outputType: ModelOutputType.Object }; - await this.checkParams({ cond: fullCond, messages, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + cond: fullCond, + messages, + options, + }); + const model = this.selectModel(normalizedCond); try { metrics.ai .counter('chat_object_stream_calls') - .add(1, { model: model.id }); - const fullStream = await this.getFullStream(model, messages, options); - const parser = new StreamObjectParser(); - for await (const chunk of fullStream) { - const result = 
parser.parse(chunk); - if (result) { - yield result; - } - if (options.signal?.aborted) { - await fullStream.cancel(); - break; - } + .add(1, this.metricLabels(model.id)); + const tools = await this.getTools(options, model.id); + const middleware = this.getActiveProviderMiddleware(); + const cap = this.getAttachCapability(model, ModelOutputType.Object); + const normalizedOptions = normalizeOpenAIOptionsForModel( + options, + model.id + ); + const { request } = await buildNativeRequest({ + model: model.id, + messages, + options: normalizedOptions, + tools, + attachmentCapability: cap, + include: options.webSearch ? ['citations'] : undefined, + reasoning: this.getReasoning(options, model.id), + middleware, + }); + const adapter = this.createNativeAdapter(tools, middleware.node?.text); + for await (const chunk of adapter.streamObject( + request, + options.signal, + messages + )) { + yield chunk; } } catch (e: any) { metrics.ai .counter('chat_object_stream_errors') - .add(1, { model: model.id }); - throw this.handleError(e, model.id, options); + .add(1, this.metricLabels(model.id)); + throw this.handleError(e); } } @@ -530,154 +656,273 @@ export class OpenAIProvider extends CopilotProvider { options: CopilotStructuredOptions = {} ): Promise { const fullCond = { ...cond, outputType: ModelOutputType.Structured }; - await this.checkParams({ messages, cond: fullCond, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + messages, + cond: fullCond, + options, + }); + const model = this.selectModel(normalizedCond); try { metrics.ai.counter('chat_text_calls').add(1, { model: model.id }); - - const [system, msgs, schema] = await chatToGPTMessage(messages); - if (!schema) { - throw new CopilotPromptInvalid('Schema is required'); - } - - const modelInstance = - 'responses' in this.#instance - ? 
this.#instance.responses(model.id) - : this.#instance(model.id); - - const { object } = await generateObject({ - model: modelInstance, - system, - messages: msgs, - temperature: options.temperature ?? 0, - maxOutputTokens: options.maxTokens ?? 4096, - maxRetries: options.maxRetries ?? 3, - schema, - providerOptions: { - openai: options.user ? { user: options.user } : {}, - }, - abortSignal: options.signal, + const backendConfig = this.createNativeConfig(); + const middleware = this.getActiveProviderMiddleware(); + const cap = this.getAttachCapability(model, ModelOutputType.Structured); + const normalizedOptions = normalizeOpenAIOptionsForModel( + options, + model.id + ); + const { request, schema } = await buildNativeStructuredRequest({ + model: model.id, + messages, + options: normalizedOptions, + attachmentCapability: cap, + reasoning: this.getReasoning(options, model.id), + responseSchema: options.schema, + middleware, }); - - return JSON.stringify(object); + const response = + await this.createNativeStructuredDispatch(backendConfig)(request); + const parsed = parseNativeStructuredOutput(response); + const validated = schema.parse(parsed); + return JSON.stringify(validated); } catch (e: any) { metrics.ai.counter('chat_text_errors').add(1, { model: model.id }); - throw this.handleError(e, model.id, options); + throw this.handleError(e); } } override async rerank( cond: ModelConditions, - chunkMessages: PromptMessage[][], + request: CopilotRerankRequest, options: CopilotChatOptions = {} ): Promise { - const fullCond = { ...cond, outputType: ModelOutputType.Text }; - await this.checkParams({ messages: [], cond: fullCond, options }); - const model = this.selectModel(fullCond); - // get the log probability of "yes"/"no" - const instance = - 'chat' in this.#instance - ? 
this.#instance.chat(model.id) - : this.#instance(model.id); - - const scores = await Promise.all( - chunkMessages.map(async messages => { - const [system, msgs] = await chatToGPTMessage(messages); - - const result = await generateText({ - model: instance, - system, - messages: msgs, - temperature: 0, - maxOutputTokens: 16, - providerOptions: { - openai: { - ...this.getOpenAIOptions(options, model.id), - logprobs: 16, - }, - }, - abortSignal: options.signal, - }); - - const topMap: Record = LogProbsSchema.parse( - result.providerMetadata?.openai?.logprobs - )[0].top_logprobs.reduce>( - (acc, { token, logprob }) => ({ ...acc, [token]: logprob }), - {} - ); - - const findLogProb = (token: string): number => { - // OpenAI often includes a leading space, so try matching '.yes', '_yes', ' yes' and 'yes' - return [...'_:. "-\t,(=_“'.split('').map(c => c + token), token] - .flatMap(v => [v, v.toLowerCase(), v.toUpperCase()]) - .reduce( - (best, key) => - (topMap[key] ?? Number.NEGATIVE_INFINITY) > best - ? topMap[key] - : best, - Number.NEGATIVE_INFINITY - ); - }; - - const logYes = findLogProb('Yes'); - const logNo = findLogProb('No'); - - const pYes = Math.exp(logYes); - const pNo = Math.exp(logNo); - const prob = pYes + pNo === 0 ? 0 : pYes / (pYes + pNo); - - return prob; - }) - ); - - return scores; - } - - private async getFullStream( - model: CopilotProviderModel, - messages: PromptMessage[], - options: CopilotChatOptions = {} - ) { - const [system, msgs] = await chatToGPTMessage(messages); - const modelInstance = - 'responses' in this.#instance - ? this.#instance.responses(model.id) - : this.#instance(model.id); - const { fullStream } = streamText({ - model: modelInstance, - system, - messages: msgs, - frequencyPenalty: options.frequencyPenalty ?? 0, - presencePenalty: options.presencePenalty ?? 0, - temperature: options.temperature ?? 0, - maxOutputTokens: options.maxTokens ?? 
4096, - providerOptions: { - openai: this.getOpenAIOptions(options, model.id), - }, - tools: await this.getTools(options, model.id), - stopWhen: stepCountIs(this.MAX_STEPS), - abortSignal: options.signal, + const fullCond = { ...cond, outputType: ModelOutputType.Rerank }; + const normalizedCond = await this.checkParams({ + messages: [], + cond: fullCond, + options, }); - return fullStream; + const model = this.selectModel(normalizedCond); + + try { + const backendConfig = this.createNativeConfig(); + const nativeRequest = buildOpenAIRerankRequest(model.id, request); + const response = + await this.createNativeRerankDispatch(backendConfig)(nativeRequest); + return response.scores; + } catch (e: any) { + throw this.handleError(e); + } } // ====== text to image ====== + private buildImageFetchOptions(url: URL) { + const baseOptions = { timeoutMs: 15_000, maxRedirects: 3 } as const; + const trustedOrigins = new Set(); + const protocol = this.AFFiNEConfig.server.https ? 'https:' : 'http:'; + const port = this.AFFiNEConfig.server.port; + const isDefaultPort = + (protocol === 'https:' && port === 443) || + (protocol === 'http:' && port === 80); + + const addHostOrigin = (host: string) => { + if (!host) return; + try { + const parsed = new URL(`${protocol}//${host}`); + if (!parsed.port && !isDefaultPort) { + parsed.port = String(port); + } + trustedOrigins.add(parsed.origin); + } catch { + // ignore invalid host config entries + } + }; + + if (this.AFFiNEConfig.server.externalUrl) { + try { + trustedOrigins.add( + new URL(this.AFFiNEConfig.server.externalUrl).origin + ); + } catch { + // ignore invalid external URL + } + } + + addHostOrigin(this.AFFiNEConfig.server.host); + for (const host of this.AFFiNEConfig.server.hosts) { + addHostOrigin(host); + } + + const hostname = url.hostname.toLowerCase(); + const trustedByHost = TRUSTED_ATTACHMENT_HOST_SUFFIXES.some( + suffix => hostname === suffix || hostname.endsWith(`.${suffix}`) + ); + if (trustedOrigins.has(url.origin) || 
trustedByHost) { + return { ...baseOptions, allowPrivateOrigins: new Set([url.origin]) }; + } + + return baseOptions; + } + + private redactUrl(raw: string | URL): string { + try { + const parsed = raw instanceof URL ? raw : new URL(raw); + if (parsed.protocol === 'data:') return 'data:[redacted]'; + const segments = parsed.pathname.split('/').filter(Boolean); + const redactedPath = + segments.length <= 2 + ? parsed.pathname || '/' + : `/${segments[0]}/${segments[1]}/...`; + return `${parsed.origin}${redactedPath}`; + } catch { + return '[invalid-url]'; + } + } + + private async fetchImage( + url: string, + maxBytes: number, + signal?: AbortSignal + ): Promise<{ buffer: Buffer; type: string } | null> { + if (url.startsWith('data:')) { + let response: Response; + try { + response = await fetch(url, { signal }); + } catch (error) { + this.logger.warn( + `Skip image attachment data URL due to read failure: ${ + error instanceof Error ? error.message : String(error) + }` + ); + return null; + } + + if (!response.ok) { + this.logger.warn( + `Skip image attachment data URL due to invalid response: ${response.status}` + ); + return null; + } + + const type = + response.headers.get('content-type') || 'application/octet-stream'; + if (!type.startsWith('image/')) { + await response.body?.cancel().catch(() => undefined); + this.logger.warn( + `Skip non-image attachment data URL with content-type ${type}` + ); + return null; + } + + try { + const buffer = await readResponseBufferWithLimit(response, maxBytes); + return { buffer, type }; + } catch (error) { + this.logger.warn( + `Skip image attachment data URL due to read failure/size limit: ${ + error instanceof Error ? 
error.message : String(error) + }` + ); + return null; + } + } + + let parsed: URL; + try { + parsed = new URL(url); + } catch { + this.logger.warn( + `Skip image attachment with invalid URL: ${this.redactUrl(url)}` + ); + return null; + } + const redactedUrl = this.redactUrl(parsed); + + if (parsed.protocol !== 'http:' && parsed.protocol !== 'https:') { + this.logger.warn( + `Skip image attachment with unsupported protocol: ${redactedUrl}` + ); + return null; + } + + let response: Response; + try { + response = await safeFetch( + parsed, + { method: 'GET', signal }, + this.buildImageFetchOptions(parsed) + ); + } catch (error) { + this.logger.warn( + `Skip image attachment due to blocked/unreachable URL: ${redactedUrl}, reason: ${ + error instanceof Error ? error.message : String(error) + }` + ); + return null; + } + + if (!response.ok) { + this.logger.warn( + `Skip image attachment fetch failure ${response.status}: ${redactedUrl}` + ); + return null; + } + + const type = + response.headers.get('content-type') || 'application/octet-stream'; + if (!type.startsWith('image/')) { + await response.body?.cancel().catch(() => undefined); + this.logger.warn( + `Skip non-image attachment with content-type ${type}: ${redactedUrl}` + ); + return null; + } + + const contentLength = Number(response.headers.get('content-length')); + if (Number.isFinite(contentLength) && contentLength > maxBytes) { + await response.body?.cancel().catch(() => undefined); + this.logger.warn( + `Skip oversized image attachment by content-length (${contentLength}): ${redactedUrl}` + ); + return null; + } + + try { + const buffer = await readResponseBufferWithLimit(response, maxBytes); + return { buffer, type }; + } catch (error) { + this.logger.warn( + `Skip image attachment due to read failure/size limit: ${redactedUrl}, reason: ${ + error instanceof Error ? 
error.message : String(error) + }` + ); + return null; + } + } + private async *generateImageWithAttachments( model: string, prompt: string, - attachments: NonNullable + attachments: NonNullable, + signal?: AbortSignal ): AsyncGenerator { const form = new FormData(); + const outputFormat = 'webp'; + const maxBytes = 10 * OneMB; form.set('model', model); form.set('prompt', prompt); - form.set('output_format', 'webp'); + form.set('output_format', outputFormat); for (const [idx, entry] of attachments.entries()) { - const url = typeof entry === 'string' ? entry : entry.attachment; + const url = promptAttachmentToUrl(entry); + if (!url) continue; try { - const { buffer, type } = await fetchBuffer(url, 10 * OneMB, 'image/'); - const file = new File([buffer], `${idx}.png`, { type }); + const attachment = await this.fetchImage(url, maxBytes, signal); + if (!attachment) continue; + const { buffer, type } = attachment; + const extension = type.split(';')[0].split('/')[1] || 'png'; + const file = new File([buffer], `${idx}.${extension}`, { type }); form.append('image[]', file); } catch { continue; @@ -703,18 +948,24 @@ export class OpenAIProvider extends CopilotProvider { const json = await res.json(); const imageResponse = ImageResponseSchema.safeParse(json); - if (imageResponse.success) { - const data = imageResponse.data; - if ('error' in data) { - throw new Error(data.error.message); - } else { - for (const image of data.data) { - yield `data:image/webp;base64,${image.b64_json}`; - } - } - } else { + if (!imageResponse.success) { throw new Error(imageResponse.error.message); } + const data = imageResponse.data; + if ('error' in data) { + throw new Error(data.error.message); + } + + const images = normalizeImageResponseData( + data.data, + normalizeImageFormatToMime(outputFormat) + ); + if (!images.length) { + throw new Error('No images returned from OpenAI'); + } + for (const image of images) { + yield image; + } } override async *streamImages( @@ -723,15 +974,12 @@ 
export class OpenAIProvider extends CopilotProvider { options: CopilotImageOptions = {} ) { const fullCond = { ...cond, outputType: ModelOutputType.Image }; - await this.checkParams({ messages, cond: fullCond, options }); - const model = this.selectModel(fullCond); - - if (!('image' in this.#instance)) { - throw new CopilotProviderNotSupported({ - provider: this.type, - kind: 'image', - }); - } + const normalizedCond = await this.checkParams({ + messages, + cond: fullCond, + options, + }); + const model = this.selectModel(normalizedCond); metrics.ai .counter('generate_images_stream_calls') @@ -742,22 +990,27 @@ export class OpenAIProvider extends CopilotProvider { try { if (attachments && attachments.length > 0) { - yield* this.generateImageWithAttachments(model.id, prompt, attachments); - } else { - const modelInstance = this.#instance.image(model.id); - const result = await generateImage({ - model: modelInstance, + yield* this.generateImageWithAttachments( + model.id, prompt, - providerOptions: { - openai: { - quality: options.quality || null, - }, - }, - }); - - const imageUrls = result.images.map( - image => `data:image/png;base64,${image.base64}` + attachments, + options.signal ); + } else { + const response = await this.requestOpenAIJson('/images/generations', { + model: model.id, + prompt, + ...(options.quality ? 
{ quality: options.quality } : {}), + }); + const imageResponse = ImageResponseSchema.parse(response); + if ('error' in imageResponse) { + throw new Error(imageResponse.error.message); + } + + const imageUrls = normalizeImageResponseData(imageResponse.data); + if (!imageUrls.length) { + throw new Error('No images returned from OpenAI'); + } for (const imageUrl of imageUrls) { yield imageUrl; @@ -769,7 +1022,7 @@ export class OpenAIProvider extends CopilotProvider { return; } catch (e: any) { metrics.ai.counter('generate_images_errors').add(1, { model: model.id }); - throw this.handleError(e, model.id, options); + throw this.handleError(e); } } @@ -778,54 +1031,59 @@ export class OpenAIProvider extends CopilotProvider { messages: string | string[], options: CopilotEmbeddingOptions = { dimensions: DEFAULT_DIMENSIONS } ): Promise { - messages = Array.isArray(messages) ? messages : [messages]; + const input = Array.isArray(messages) ? messages : [messages]; const fullCond = { ...cond, outputType: ModelOutputType.Embedding }; - await this.checkParams({ embeddings: messages, cond: fullCond, options }); - const model = this.selectModel(fullCond); - - if (!('embedding' in this.#instance)) { - throw new CopilotProviderNotSupported({ - provider: this.type, - kind: 'embedding', - }); - } + const normalizedCond = await this.checkParams({ + embeddings: input, + cond: fullCond, + options, + }); + const model = this.selectModel(normalizedCond); try { metrics.ai .counter('generate_embedding_calls') - .add(1, { model: model.id }); - - const modelInstance = this.#instance.embedding(model.id); - - const { embeddings } = await embedMany({ - model: modelInstance, - values: messages, - providerOptions: { - openai: { - dimensions: options.dimensions || DEFAULT_DIMENSIONS, - }, - }, - }); - - return embeddings.filter(v => v && Array.isArray(v)); + .add(1, this.metricLabels(model.id)); + const backendConfig = this.createNativeConfig(); + const response = await 
this.createNativeEmbeddingDispatch(backendConfig)( + buildNativeEmbeddingRequest({ + model: model.id, + inputs: input, + dimensions: options.dimensions || DEFAULT_DIMENSIONS, + }) + ); + return response.embeddings; } catch (e: any) { metrics.ai .counter('generate_embedding_errors') - .add(1, { model: model.id }); - throw this.handleError(e, model.id, options); + .add(1, this.metricLabels(model.id)); + throw this.handleError(e); } } - private getOpenAIOptions(options: CopilotChatOptions, model: string) { - const result: OpenAIResponsesProviderOptions = {}; - if (options?.reasoning && this.isReasoningModel(model)) { - result.reasoningEffort = 'medium'; - result.reasoningSummary = 'detailed'; + private async requestOpenAIJson( + path: string, + body: Record, + signal?: AbortSignal + ): Promise { + const baseUrl = this.config.baseURL || 'https://api.openai.com/v1'; + const response = await fetch(`${baseUrl}${path}`, { + method: 'POST', + headers: { + Authorization: `Bearer ${this.config.apiKey}`, + 'Content-Type': 'application/json', + }, + body: JSON.stringify(body), + signal, + }); + + if (!response.ok) { + throw new Error( + `OpenAI API error ${response.status}: ${await response.text()}` + ); } - if (options?.user) { - result.user = options.user; - } - return result; + + return await response.json(); } private isReasoningModel(model: string) { diff --git a/packages/backend/server/src/plugins/copilot/providers/perplexity.ts b/packages/backend/server/src/plugins/copilot/providers/perplexity.ts index b49f7ece1b..0bfd030779 100644 --- a/packages/backend/server/src/plugins/copilot/providers/perplexity.ts +++ b/packages/backend/server/src/plugins/copilot/providers/perplexity.ts @@ -1,11 +1,12 @@ -import { - createPerplexity, - type PerplexityProvider as VercelPerplexityProvider, -} from '@ai-sdk/perplexity'; -import { generateText, streamText } from 'ai'; -import { z } from 'zod'; - import { CopilotProviderSideError, metrics } from '../../../base'; +import { + 
llmDispatchStream, + type NativeLlmBackendConfig, + type NativeLlmRequest, +} from '../../../native'; +import type { NodeTextMiddleware } from '../config'; +import type { CopilotToolSet } from '../tools'; +import { buildNativeRequest, NativeProviderAdapter } from './native'; import { CopilotProvider } from './provider'; import { CopilotChatOptions, @@ -15,34 +16,12 @@ import { ModelOutputType, PromptMessage, } from './types'; -import { chatToGPTMessage, CitationParser } from './utils'; export type PerplexityConfig = { apiKey: string; endpoint?: string; }; -const PerplexityErrorSchema = z.union([ - z.object({ - detail: z.array( - z.object({ - loc: z.array(z.string()), - msg: z.string(), - type: z.string(), - }) - ), - }), - z.object({ - error: z.object({ - message: z.string(), - type: z.string(), - code: z.number(), - }), - }), -]); - -type PerplexityError = z.infer; - export class PerplexityProvider extends CopilotProvider { readonly type = CopilotProviderType.Perplexity; @@ -90,18 +69,38 @@ export class PerplexityProvider extends CopilotProvider { }, ]; - #instance!: VercelPerplexityProvider; - override configured(): boolean { return !!this.config.apiKey; } protected override setup() { super.setup(); - this.#instance = createPerplexity({ - apiKey: this.config.apiKey, - baseURL: this.config.endpoint, - }); + } + + private createNativeConfig(): NativeLlmBackendConfig { + const baseUrl = this.config.endpoint || 'https://api.perplexity.ai'; + return { + base_url: baseUrl.replace(/\/v1\/?$/, ''), + auth_token: this.config.apiKey, + }; + } + + private createNativeAdapter( + tools: CopilotToolSet, + nodeTextMiddleware?: NodeTextMiddleware[] + ) { + return new NativeProviderAdapter( + (request: NativeLlmRequest, signal?: AbortSignal) => + llmDispatchStream( + 'openai_chat', + this.createNativeConfig(), + request, + signal + ), + tools, + this.MAX_STEPS, + { nodeTextMiddleware } + ); } async text( @@ -110,36 +109,34 @@ export class PerplexityProvider extends 
CopilotProvider { options: CopilotChatOptions = {} ): Promise { const fullCond = { ...cond, outputType: ModelOutputType.Text }; - await this.checkParams({ cond: fullCond, messages, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + cond: fullCond, + messages, + options, + withAttachment: false, + }); + const model = this.selectModel(normalizedCond); try { - metrics.ai.counter('chat_text_calls').add(1, { model: model.id }); + metrics.ai.counter('chat_text_calls').add(1, this.metricLabels(model.id)); - const [system, msgs] = await chatToGPTMessage(messages, false); - - const modelInstance = this.#instance(model.id); - - const { text, sources } = await generateText({ - model: modelInstance, - system, - messages: msgs, - temperature: options.temperature ?? 0, - maxOutputTokens: options.maxTokens ?? 4096, - abortSignal: options.signal, + const tools = await this.getTools(options, model.id); + const middleware = this.getActiveProviderMiddleware(); + const { request } = await buildNativeRequest({ + model: model.id, + messages, + options, + tools, + withAttachment: false, + include: ['citations'], + middleware, }); - - const parser = new CitationParser(); - for (const source of sources.filter(s => s.sourceType === 'url')) { - parser.push(source.url); - } - - let result = text.replaceAll(/<\/?think>\n/g, '\n---\n'); - result = parser.parse(result); - result += parser.end(); - return result; + const adapter = this.createNativeAdapter(tools, middleware.node?.text); + return await adapter.text(request, options.signal, messages); } catch (e: any) { - metrics.ai.counter('chat_text_errors').add(1, { model: model.id }); + metrics.ai + .counter('chat_text_errors') + .add(1, this.metricLabels(model.id)); throw this.handleError(e); } } @@ -150,83 +147,46 @@ export class PerplexityProvider extends CopilotProvider { options: CopilotChatOptions = {} ): AsyncIterable { const fullCond = { ...cond, outputType: ModelOutputType.Text }; 
- await this.checkParams({ cond: fullCond, messages, options }); - const model = this.selectModel(fullCond); + const normalizedCond = await this.checkParams({ + cond: fullCond, + messages, + options, + withAttachment: false, + }); + const model = this.selectModel(normalizedCond); try { - metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id }); + metrics.ai + .counter('chat_text_stream_calls') + .add(1, this.metricLabels(model.id)); - const [system, msgs] = await chatToGPTMessage(messages, false); - - const modelInstance = this.#instance(model.id); - - const stream = streamText({ - model: modelInstance, - system, - messages: msgs, - temperature: options.temperature ?? 0, - maxOutputTokens: options.maxTokens ?? 4096, - abortSignal: options.signal, + const tools = await this.getTools(options, model.id); + const middleware = this.getActiveProviderMiddleware(); + const { request } = await buildNativeRequest({ + model: model.id, + messages, + options, + tools, + withAttachment: false, + include: ['citations'], + middleware, }); - - const parser = new CitationParser(); - for await (const chunk of stream.fullStream) { - switch (chunk.type) { - case 'source': { - if (chunk.sourceType === 'url') { - parser.push(chunk.url); - } - break; - } - case 'text-delta': { - const text = chunk.text.replaceAll(/<\/?think>\n?/g, '\n---\n'); - const result = parser.parse(text); - yield result; - break; - } - case 'finish-step': { - const result = parser.end(); - yield result; - break; - } - case 'error': { - const json = - typeof chunk.error === 'string' - ? 
JSON.parse(chunk.error) - : chunk.error; - if (json && typeof json === 'object') { - const data = PerplexityErrorSchema.parse(json); - if ('detail' in data || 'error' in data) { - throw this.convertError(data); - } - } - } - } + const adapter = this.createNativeAdapter(tools, middleware.node?.text); + for await (const chunk of adapter.streamText( + request, + options.signal, + messages + )) { + yield chunk; } - } catch (e) { - metrics.ai.counter('chat_text_stream_errors').add(1, { model: model.id }); - throw e; + } catch (e: any) { + metrics.ai + .counter('chat_text_stream_errors') + .add(1, this.metricLabels(model.id)); + throw this.handleError(e); } } - private convertError(e: PerplexityError) { - function getErrMessage(e: PerplexityError) { - let err = 'Unexpected perplexity response'; - if ('detail' in e) { - err = e.detail[0].msg || err; - } else if ('error' in e) { - err = e.error.message || err; - } - return err; - } - - throw new CopilotProviderSideError({ - provider: this.type, - kind: 'unexpected_response', - message: getErrMessage(e), - }); - } - private handleError(e: any) { if (e instanceof CopilotProviderSideError) { return e; diff --git a/packages/backend/server/src/plugins/copilot/providers/provider-middleware.ts b/packages/backend/server/src/plugins/copilot/providers/provider-middleware.ts new file mode 100644 index 0000000000..d43d30c5bc --- /dev/null +++ b/packages/backend/server/src/plugins/copilot/providers/provider-middleware.ts @@ -0,0 +1,106 @@ +import type { ProviderMiddlewareConfig } from '../config'; +import { CopilotProviderType } from './types'; + +const DEFAULT_MIDDLEWARE_BY_TYPE: Record< + CopilotProviderType, + ProviderMiddlewareConfig +> = { + [CopilotProviderType.OpenAI]: { + rust: { + request: ['normalize_messages'], + stream: ['stream_event_normalize', 'citation_indexing'], + }, + node: { + text: ['citation_footnote', 'callout'], + }, + }, + [CopilotProviderType.Anthropic]: { + rust: { + request: ['normalize_messages', 
'tool_schema_rewrite'], + stream: ['stream_event_normalize', 'citation_indexing'], + }, + node: { + text: ['citation_footnote', 'callout'], + }, + }, + [CopilotProviderType.AnthropicVertex]: { + rust: { + request: ['normalize_messages', 'tool_schema_rewrite'], + stream: ['stream_event_normalize', 'citation_indexing'], + }, + node: { + text: ['citation_footnote', 'callout'], + }, + }, + [CopilotProviderType.Morph]: { + rust: { + request: ['clamp_max_tokens'], + stream: ['stream_event_normalize', 'citation_indexing'], + }, + node: { + text: ['citation_footnote', 'callout'], + }, + }, + [CopilotProviderType.Perplexity]: { + rust: { + request: ['clamp_max_tokens'], + stream: ['stream_event_normalize', 'citation_indexing'], + }, + node: { + text: ['citation_footnote', 'callout'], + }, + }, + [CopilotProviderType.Gemini]: { + rust: { + request: ['normalize_messages', 'tool_schema_rewrite'], + stream: ['stream_event_normalize', 'citation_indexing'], + }, + node: { + text: ['citation_footnote', 'callout'], + }, + }, + [CopilotProviderType.GeminiVertex]: { + rust: { + request: ['normalize_messages', 'tool_schema_rewrite'], + stream: ['stream_event_normalize', 'citation_indexing'], + }, + node: { + text: ['citation_footnote', 'callout'], + }, + }, + [CopilotProviderType.FAL]: {}, +}; + +function unique(items: T[]) { + return [...new Set(items)]; +} + +function mergeArray(base: T[] | undefined, override: T[] | undefined) { + if (!base?.length && !override?.length) { + return undefined; + } + return unique([...(base ?? []), ...(override ?? 
[])]); +} + +export function mergeProviderMiddleware( + defaults: ProviderMiddlewareConfig, + override?: ProviderMiddlewareConfig +): ProviderMiddlewareConfig { + return { + rust: { + request: mergeArray(defaults.rust?.request, override?.rust?.request), + stream: mergeArray(defaults.rust?.stream, override?.rust?.stream), + }, + node: { + text: mergeArray(defaults.node?.text, override?.node?.text), + }, + }; +} + +export function resolveProviderMiddleware( + type: CopilotProviderType, + override?: ProviderMiddlewareConfig +): ProviderMiddlewareConfig { + const defaults = DEFAULT_MIDDLEWARE_BY_TYPE[type] ?? {}; + return mergeProviderMiddleware(defaults, override); +} diff --git a/packages/backend/server/src/plugins/copilot/providers/provider-registry.ts b/packages/backend/server/src/plugins/copilot/providers/provider-registry.ts new file mode 100644 index 0000000000..73d1bbdc8b --- /dev/null +++ b/packages/backend/server/src/plugins/copilot/providers/provider-registry.ts @@ -0,0 +1,278 @@ +import type { + CopilotProviderConfigMap, + CopilotProviderDefaults, + CopilotProviderProfile, + ProviderMiddlewareConfig, +} from '../config'; +import { resolveProviderMiddleware } from './provider-middleware'; +import { CopilotProviderType, ModelOutputType } from './types'; + +const PROVIDER_ID_PATTERN = /^[a-zA-Z0-9-_]+$/; + +const LEGACY_PROVIDER_ORDER: CopilotProviderType[] = [ + CopilotProviderType.OpenAI, + CopilotProviderType.FAL, + CopilotProviderType.Gemini, + CopilotProviderType.GeminiVertex, + CopilotProviderType.Perplexity, + CopilotProviderType.Anthropic, + CopilotProviderType.AnthropicVertex, + CopilotProviderType.Morph, +]; + +const LEGACY_PROVIDER_PRIORITY = LEGACY_PROVIDER_ORDER.reduce( + (acc, type, index) => { + acc[type] = LEGACY_PROVIDER_ORDER.length - index; + return acc; + }, + {} as Record +); + +type LegacyProvidersConfig = Partial< + Record +>; + +export type CopilotProvidersConfigInput = LegacyProvidersConfig & { + profiles?: CopilotProviderProfile[] | 
null; + defaults?: CopilotProviderDefaults | null; +}; + +export type NormalizedCopilotProviderProfile = Omit< + CopilotProviderProfile, + 'enabled' | 'priority' | 'middleware' +> & { + enabled: boolean; + priority: number; + middleware: ProviderMiddlewareConfig; +}; + +export type CopilotProviderRegistry = { + profiles: Map; + defaults: CopilotProviderDefaults; + order: string[]; + byType: Map; +}; + +export type ResolveModelResult = { + rawModelId?: string; + modelId?: string; + explicitProviderId?: string; + candidateProviderIds: string[]; +}; + +type ResolveModelOptions = { + registry: CopilotProviderRegistry; + modelId?: string; + outputType?: ModelOutputType; + availableProviderIds?: Iterable; + preferredProviderIds?: Iterable; +}; + +function unique(list: T[]): T[] { + return [...new Set(list)]; +} + +function asArray(iter?: Iterable): T[] { + return iter ? Array.from(iter) : []; +} + +function parseModelPrefix( + registry: CopilotProviderRegistry, + modelId: string +): { providerId: string; modelId?: string } | null { + const index = modelId.indexOf('/'); + if (index <= 0) { + return null; + } + + const providerId = modelId.slice(0, index); + if (!registry.profiles.has(providerId)) { + return null; + } + + const model = modelId.slice(index + 1); + return { providerId, modelId: model || undefined }; +} + +function normalizeProfile( + profile: CopilotProviderProfile +): NormalizedCopilotProviderProfile { + return { + ...profile, + enabled: profile.enabled !== false, + priority: profile.priority ?? 
0, + middleware: resolveProviderMiddleware(profile.type, profile.middleware), + }; +} + +function toLegacyProfiles( + config: CopilotProvidersConfigInput +): CopilotProviderProfile[] { + const legacyProfiles: CopilotProviderProfile[] = []; + for (const type of LEGACY_PROVIDER_ORDER) { + const legacyConfig = config[type]; + if (!legacyConfig) { + continue; + } + legacyProfiles.push({ + id: `${type}-default`, + type, + priority: LEGACY_PROVIDER_PRIORITY[type], + config: legacyConfig, + } as CopilotProviderProfile); + } + return legacyProfiles; +} + +function mergeProfiles( + explicitProfiles: CopilotProviderProfile[], + legacyProfiles: CopilotProviderProfile[] +): CopilotProviderProfile[] { + const profiles = new Map(); + + for (const profile of explicitProfiles) { + if (!PROVIDER_ID_PATTERN.test(profile.id)) { + throw new Error(`Invalid copilot provider profile id: ${profile.id}`); + } + if (profiles.has(profile.id)) { + throw new Error(`Duplicated copilot provider profile id: ${profile.id}`); + } + profiles.set(profile.id, profile); + } + + for (const profile of legacyProfiles) { + if (!profiles.has(profile.id)) { + profiles.set(profile.id, profile); + } + } + + return Array.from(profiles.values()); +} + +function sortProfiles(profiles: NormalizedCopilotProviderProfile[]) { + return profiles.toSorted((a, b) => { + if (a.priority !== b.priority) { + return b.priority - a.priority; + } + return a.id.localeCompare(b.id); + }); +} + +function assertDefaults( + defaults: CopilotProviderDefaults, + profiles: Map +) { + for (const providerId of Object.values(defaults)) { + if (!providerId) { + continue; + } + if (!profiles.has(providerId)) { + throw new Error( + `Copilot provider defaults references unknown providerId: ${providerId}` + ); + } + } +} + +export function buildProviderRegistry( + config: CopilotProvidersConfigInput +): CopilotProviderRegistry { + const explicitProfiles = config.profiles ?? 
[]; + const legacyProfiles = toLegacyProfiles(config); + const mergedProfiles = mergeProfiles(explicitProfiles, legacyProfiles) + .map(normalizeProfile) + .filter(profile => profile.enabled); + const sortedProfiles = sortProfiles(mergedProfiles); + + const profiles = new Map( + sortedProfiles.map(profile => [profile.id, profile] as const) + ); + const defaults = config.defaults ?? {}; + assertDefaults(defaults, profiles); + + const order = sortedProfiles.map(profile => profile.id); + const byType = new Map(); + for (const profile of sortedProfiles) { + const ids = byType.get(profile.type) ?? []; + ids.push(profile.id); + byType.set(profile.type, ids); + } + + return { profiles, defaults, order, byType }; +} + +export function resolveModel({ + registry, + modelId, + outputType, + availableProviderIds, + preferredProviderIds, +}: ResolveModelOptions): ResolveModelResult { + const available = new Set(asArray(availableProviderIds)); + const preferred = new Set(asArray(preferredProviderIds)); + const hasAvailableFilter = available.size > 0; + const hasPreferredFilter = preferred.size > 0; + + const isAllowed = (providerId: string) => { + const profile = registry.profiles.get(providerId); + if (!profile?.enabled) { + return false; + } + if (hasAvailableFilter && !available.has(providerId)) { + return false; + } + if (hasPreferredFilter && !preferred.has(providerId)) { + return false; + } + return true; + }; + + const prefixed = modelId ? parseModelPrefix(registry, modelId) : null; + if (prefixed) { + return { + rawModelId: modelId, + modelId: prefixed.modelId, + explicitProviderId: prefixed.providerId, + candidateProviderIds: isAllowed(prefixed.providerId) + ? [prefixed.providerId] + : [], + }; + } + + const defaultProviderId = + outputType && outputType !== ModelOutputType.Rerank + ? registry.defaults[outputType] + : undefined; + + const fallbackOrder = [ + ...(defaultProviderId ? 
[defaultProviderId] : []), + registry.defaults.fallback, + ...registry.order, + ].filter((id): id is string => !!id); + + return { + rawModelId: modelId, + modelId, + candidateProviderIds: unique( + fallbackOrder.filter(providerId => isAllowed(providerId)) + ), + }; +} + +export function stripProviderPrefix( + registry: CopilotProviderRegistry, + providerId: string, + modelId?: string +) { + if (!modelId) { + return modelId; + } + const prefixed = parseModelPrefix(registry, modelId); + if (!prefixed) { + return modelId; + } + if (prefixed.providerId !== providerId) { + return modelId; + } + return prefixed.modelId; +} diff --git a/packages/backend/server/src/plugins/copilot/providers/provider.ts b/packages/backend/server/src/plugins/copilot/providers/provider.ts index b021854959..5ffab6818b 100644 --- a/packages/backend/server/src/plugins/copilot/providers/provider.ts +++ b/packages/backend/server/src/plugins/copilot/providers/provider.ts @@ -1,6 +1,7 @@ +import { AsyncLocalStorage } from 'node:async_hooks'; + import { Inject, Injectable, Logger } from '@nestjs/common'; import { ModuleRef } from '@nestjs/core'; -import { Tool, ToolSet } from 'ai'; import { z } from 'zod'; import { @@ -13,6 +14,7 @@ import { DocReader, DocWriter } from '../../../core/doc'; import { AccessController } from '../../../core/permission'; import { Models } from '../../../models'; import { IndexerService } from '../../indexer'; +import type { ProviderMiddlewareConfig } from '../config'; import { CopilotContextService } from '../context/service'; import { PromptService } from '../prompt/service'; import { @@ -24,6 +26,8 @@ import { buildDocSearchGetter, buildDocUpdateHandler, buildDocUpdateMetaHandler, + type CopilotTool, + type CopilotToolSet, createBlobReadTool, createCodeArtifactTool, createConversationSummaryTool, @@ -39,7 +43,10 @@ import { createExaSearchTool, createSectionEditTool, } from '../tools'; +import { canonicalizePromptAttachment } from './attachments'; import { 
CopilotProviderFactory } from './factory'; +import { resolveProviderMiddleware } from './provider-middleware'; +import { buildProviderRegistry } from './provider-registry'; import { type CopilotChatOptions, CopilotChatTools, @@ -47,22 +54,30 @@ import { type CopilotImageOptions, CopilotProviderModel, CopilotProviderType, + type CopilotRerankRequest, CopilotStructuredOptions, EmbeddingMessage, + type ModelAttachmentCapability, ModelCapability, ModelConditions, ModelFullConditions, ModelInputType, + ModelOutputType, + type PromptAttachmentKind, + type PromptAttachmentSourceKind, type PromptMessage, PromptMessageSchema, StreamObject, } from './types'; +const providerProfileContext = new AsyncLocalStorage(); + @Injectable() export abstract class CopilotProvider { protected readonly logger = new Logger(this.constructor.name); protected readonly MAX_STEPS = 20; protected onlineModelList: string[] = []; + abstract readonly type: CopilotProviderType; abstract readonly models: CopilotProviderModel[]; abstract configured(): boolean; @@ -70,8 +85,39 @@ export abstract class CopilotProvider { @Inject() protected readonly AFFiNEConfig!: Config; @Inject() protected readonly factory!: CopilotProviderFactory; @Inject() protected readonly moduleRef!: ModuleRef; + readonly #registeredProviderIds = new Set(); + + runWithProfile(providerId: string, callback: () => T): T { + return providerProfileContext.run(providerId, callback); + } + + protected getActiveProviderId() { + return providerProfileContext.getStore() ?? `${this.type}-default`; + } + + protected getActiveProviderMiddleware(): ProviderMiddlewareConfig { + const providerId = this.getActiveProviderId(); + const registry = buildProviderRegistry(this.AFFiNEConfig.copilot.providers); + const profile = registry.profiles.get(providerId); + return profile?.middleware ?? 
resolveProviderMiddleware(this.type); + } + + protected metricLabels( + model: string, + labels: Record = {} + ) { + const providerId = this.getActiveProviderId(); + return { model, providerId, ...labels }; + } get config(): C { + const profileId = providerProfileContext.getStore(); + if (profileId) { + const profile = this.AFFiNEConfig.copilot.providers.profiles?.find( + profile => profile.id === profileId && profile.type === this.type + ); + if (profile) return profile.config as C; + } return this.AFFiNEConfig.copilot.providers[this.type] as C; } @@ -88,20 +134,199 @@ export abstract class CopilotProvider { } protected setup() { - if (this.configured()) { - this.factory.register(this); - if (env.selfhosted) { + const registry = buildProviderRegistry(this.AFFiNEConfig.copilot.providers); + const providerIds = registry.byType.get(this.type) ?? []; + const nextProviderIds = new Set(); + + for (const id of providerIds) { + const configured = this.runWithProfile(id, () => this.configured()); + if (configured) { + nextProviderIds.add(id); + this.factory.register(id, this); + } else { + this.factory.unregister(id, this); + } + } + + for (const providerId of this.#registeredProviderIds) { + if (!nextProviderIds.has(providerId)) { + this.factory.unregister(providerId, this); + } + } + this.#registeredProviderIds.clear(); + for (const providerId of nextProviderIds) { + this.#registeredProviderIds.add(providerId); + } + + if (env.selfhosted && nextProviderIds.size > 0) { + const [providerId] = Array.from(nextProviderIds); + this.runWithProfile(providerId, () => { this.refreshOnlineModels().catch(e => this.logger.error('Failed to refresh online models', e) ); - } - } else { - this.factory.unregister(this); + }); } } async refreshOnlineModels() {} + private unique(values: Iterable) { + return Array.from(new Set(values)); + } + + private attachmentKindToInputType( + kind: PromptAttachmentKind + ): ModelInputType { + switch (kind) { + case 'image': + return 
ModelInputType.Image; + case 'audio': + return ModelInputType.Audio; + default: + return ModelInputType.File; + } + } + + protected async inferModelConditionsFromMessages( + messages?: PromptMessage[], + withAttachment = true + ): Promise> { + if (!messages?.length || !withAttachment) return {}; + + const attachmentKinds: PromptAttachmentKind[] = []; + const attachmentSourceKinds: PromptAttachmentSourceKind[] = []; + const inputTypes: ModelInputType[] = []; + let hasRemoteAttachments = false; + + for (const message of messages) { + if (!Array.isArray(message.attachments)) continue; + + for (const attachment of message.attachments) { + const normalized = await canonicalizePromptAttachment( + attachment, + message + ); + attachmentKinds.push(normalized.kind); + inputTypes.push(this.attachmentKindToInputType(normalized.kind)); + attachmentSourceKinds.push(normalized.sourceKind); + hasRemoteAttachments = hasRemoteAttachments || normalized.isRemote; + } + } + + return { + ...(attachmentKinds.length + ? { attachmentKinds: this.unique(attachmentKinds) } + : {}), + ...(attachmentSourceKinds.length + ? { attachmentSourceKinds: this.unique(attachmentSourceKinds) } + : {}), + ...(inputTypes.length ? { inputTypes: this.unique(inputTypes) } : {}), + ...(hasRemoteAttachments ? { hasRemoteAttachments } : {}), + }; + } + + private mergeModelConditions( + cond: ModelFullConditions, + inferredCond: Partial + ): ModelFullConditions { + return { + ...inferredCond, + ...cond, + inputTypes: this.unique([ + ...(inferredCond.inputTypes ?? []), + ...(cond.inputTypes ?? []), + ]), + attachmentKinds: this.unique([ + ...(inferredCond.attachmentKinds ?? []), + ...(cond.attachmentKinds ?? []), + ]), + attachmentSourceKinds: this.unique([ + ...(inferredCond.attachmentSourceKinds ?? []), + ...(cond.attachmentSourceKinds ?? []), + ]), + hasRemoteAttachments: + cond.hasRemoteAttachments ?? 
inferredCond.hasRemoteAttachments, + }; + } + + protected getAttachCapability( + model: CopilotProviderModel, + outputType: ModelOutputType + ): ModelAttachmentCapability | undefined { + const capability = + model.capabilities.find(cap => cap.output.includes(outputType)) ?? + model.capabilities[0]; + if (!capability) { + return; + } + return this.resolveAttachmentCapability(capability, outputType); + } + + private resolveAttachmentCapability( + cap: ModelCapability, + outputType?: ModelOutputType + ): ModelAttachmentCapability | undefined { + if (outputType === ModelOutputType.Structured) { + return cap.structuredAttachments ?? cap.attachments; + } + return cap.attachments; + } + + private matchesAttachCapability( + cap: ModelCapability, + cond: ModelFullConditions + ) { + const { + attachmentKinds, + attachmentSourceKinds, + hasRemoteAttachments, + outputType, + } = cond; + + if ( + !attachmentKinds?.length && + !attachmentSourceKinds?.length && + !hasRemoteAttachments + ) { + return true; + } + + const attachmentCapability = this.resolveAttachmentCapability( + cap, + outputType + ); + if (!attachmentCapability) { + return !attachmentKinds?.some( + kind => !cap.input.includes(this.attachmentKindToInputType(kind)) + ); + } + + if ( + attachmentKinds?.some(kind => !attachmentCapability.kinds.includes(kind)) + ) { + return false; + } + + if ( + attachmentSourceKinds?.length && + attachmentCapability.sourceKinds?.length && + attachmentSourceKinds.some( + kind => !attachmentCapability.sourceKinds?.includes(kind) + ) + ) { + return false; + } + + if ( + hasRemoteAttachments && + attachmentCapability.allowRemoteUrls === false + ) { + return false; + } + + return true; + } + private findValidModel( cond: ModelFullConditions ): CopilotProviderModel | undefined { @@ -109,7 +334,8 @@ export abstract class CopilotProvider { const matcher = (cap: ModelCapability) => (!outputType || cap.output.includes(outputType)) && (!inputTypes?.length || - inputTypes.every(type => 
cap.input.includes(type))); + inputTypes.every(type => cap.input.includes(type))) && + this.matchesAttachCapability(cap, cond); if (modelId) { const hasOnlineModel = this.onlineModelList.includes(modelId); @@ -152,7 +378,7 @@ export abstract class CopilotProvider { protected getProviderSpecificTools( _toolName: CopilotChatTools, _model: string - ): [string, Tool?] | undefined { + ): [string, CopilotTool?] | undefined { return; } @@ -160,8 +386,8 @@ export abstract class CopilotProvider { protected async getTools( options: CopilotChatOptions, model: string - ): Promise { - const tools: ToolSet = {}; + ): Promise { + const tools: CopilotToolSet = {}; if (options?.tools?.length) { this.logger.debug(`getTools: ${JSON.stringify(options.tools)}`); const ac = this.moduleRef.get(AccessController, { strict: false }); @@ -316,19 +542,14 @@ export abstract class CopilotProvider { messages, embeddings, options = {}, + withAttachment = true, }: { cond: ModelFullConditions; messages?: PromptMessage[]; embeddings?: string[]; - options?: CopilotChatOptions; - }) { - const model = this.selectModel(cond); - const multimodal = model.capabilities.some(c => - [ModelInputType.Image, ModelInputType.Audio].some(t => - c.input.includes(t) - ) - ); - + options?: CopilotChatOptions | CopilotStructuredOptions; + withAttachment?: boolean; + }): Promise { if (messages) { const { requireContent = true, requireAttachment = false } = options; @@ -341,20 +562,56 @@ export abstract class CopilotProvider { }) .passthrough() .catchall(z.union([z.string(), z.number(), z.date(), z.null()])) - .refine( - m => - !(multimodal && requireAttachment && m.role === 'user') || - (m.attachments ? 
m.attachments.length > 0 : true), - { message: 'attachments required in multimodal mode' } - ) ) .optional(); this.handleZodError(MessageSchema.safeParse(messages)); + + const inferredCond = await this.inferModelConditionsFromMessages( + messages, + withAttachment + ); + const mergedCond = this.mergeModelConditions(cond, inferredCond); + const model = this.selectModel(mergedCond); + const multimodal = model.capabilities.some(c => + [ModelInputType.Image, ModelInputType.Audio, ModelInputType.File].some( + t => c.input.includes(t) + ) + ); + + if ( + multimodal && + requireAttachment && + !messages.some( + message => + message.role === 'user' && + Array.isArray(message.attachments) && + message.attachments.length > 0 + ) + ) { + throw new CopilotPromptInvalid( + 'attachments required in multimodal mode' + ); + } + + if (embeddings) { + this.handleZodError(EmbeddingMessage.safeParse(embeddings)); + } + + return mergedCond; } + + const inferredCond = await this.inferModelConditionsFromMessages( + messages, + withAttachment + ); + const mergedCond = this.mergeModelConditions(cond, inferredCond); + if (embeddings) { this.handleZodError(EmbeddingMessage.safeParse(embeddings)); } + + return mergedCond; } abstract text( @@ -415,7 +672,7 @@ export abstract class CopilotProvider { async rerank( _model: ModelConditions, - _messages: PromptMessage[][], + _request: CopilotRerankRequest, _options?: CopilotChatOptions ): Promise { throw new CopilotProviderNotSupported({ diff --git a/packages/backend/server/src/plugins/copilot/providers/types.ts b/packages/backend/server/src/plugins/copilot/providers/types.ts index 2a06e832f3..9da70b97a4 100644 --- a/packages/backend/server/src/plugins/copilot/providers/types.ts +++ b/packages/backend/server/src/plugins/copilot/providers/types.ts @@ -124,14 +124,97 @@ export const ChatMessageRole = Object.values(AiPromptRole) as [ 'user', ]; +const AttachmentUrlSchema = z.string().refine(value => { + if (value.startsWith('data:')) { + return true; 
+ } + + try { + const url = new URL(value); + return ( + url.protocol === 'http:' || + url.protocol === 'https:' || + url.protocol === 'gs:' + ); + } catch { + return false; + } +}, 'attachments must use https?://, gs:// or data: urls'); + +export const PromptAttachmentSourceKindSchema = z.enum([ + 'url', + 'data', + 'bytes', + 'file_handle', +]); + +export const PromptAttachmentKindSchema = z.enum(['image', 'audio', 'file']); + +const AttachmentProviderHintSchema = z + .object({ + provider: z.nativeEnum(CopilotProviderType).optional(), + kind: PromptAttachmentKindSchema.optional(), + }) + .strict(); + +const PromptAttachmentSchema = z.discriminatedUnion('kind', [ + z + .object({ + kind: z.literal('url'), + url: AttachmentUrlSchema, + mimeType: z.string().optional(), + fileName: z.string().optional(), + providerHint: AttachmentProviderHintSchema.optional(), + }) + .strict(), + z + .object({ + kind: z.literal('data'), + data: z.string(), + mimeType: z.string(), + encoding: z.enum(['base64', 'utf8']).optional(), + fileName: z.string().optional(), + providerHint: AttachmentProviderHintSchema.optional(), + }) + .strict(), + z + .object({ + kind: z.literal('bytes'), + data: z.string(), + mimeType: z.string(), + encoding: z.literal('base64').optional(), + fileName: z.string().optional(), + providerHint: AttachmentProviderHintSchema.optional(), + }) + .strict(), + z + .object({ + kind: z.literal('file_handle'), + fileHandle: z.string().trim().min(1), + mimeType: z.string().optional(), + fileName: z.string().optional(), + providerHint: AttachmentProviderHintSchema.optional(), + }) + .strict(), +]); + export const ChatMessageAttachment = z.union([ - z.string().url(), + AttachmentUrlSchema, z.object({ - attachment: z.string(), + attachment: AttachmentUrlSchema, mimeType: z.string(), }), + PromptAttachmentSchema, ]); +export const PromptResponseFormatSchema = z + .object({ + type: z.literal('json_schema'), + schema: z.any(), + strict: z.boolean().optional(), + }) + .strict(); 
+ export const StreamObjectSchema = z.discriminatedUnion('type', [ z.object({ type: z.literal('text-delta'), @@ -161,6 +244,7 @@ export const PureMessageSchema = z.object({ streamObjects: z.array(StreamObjectSchema).optional().nullable(), attachments: z.array(ChatMessageAttachment).optional().nullable(), params: z.record(z.any()).optional().nullable(), + responseFormat: PromptResponseFormatSchema.optional().nullable(), }); export const PromptMessageSchema = PureMessageSchema.extend({ @@ -169,6 +253,12 @@ export const PromptMessageSchema = PureMessageSchema.extend({ export type PromptMessage = z.infer; export type PromptParams = NonNullable; export type StreamObject = z.infer; +export type PromptAttachment = z.infer; +export type PromptAttachmentSourceKind = z.infer< + typeof PromptAttachmentSourceKindSchema +>; +export type PromptAttachmentKind = z.infer; +export type PromptResponseFormat = z.infer; // ========== options ========== @@ -194,7 +284,9 @@ export type CopilotChatTools = NonNullable< >[number]; export const CopilotStructuredOptionsSchema = - CopilotProviderOptionsSchema.merge(PromptConfigStrictSchema).optional(); + CopilotProviderOptionsSchema.merge(PromptConfigStrictSchema) + .extend({ schema: z.any().optional(), strict: z.boolean().optional() }) + .optional(); export type CopilotStructuredOptions = z.infer< typeof CopilotStructuredOptionsSchema @@ -220,10 +312,22 @@ export type CopilotEmbeddingOptions = z.infer< typeof CopilotEmbeddingOptionsSchema >; +export type CopilotRerankCandidate = { + id?: string; + text: string; +}; + +export type CopilotRerankRequest = { + query: string; + candidates: CopilotRerankCandidate[]; + topK?: number; +}; + export enum ModelInputType { Text = 'text', Image = 'image', Audio = 'audio', + File = 'file', } export enum ModelOutputType { @@ -231,12 +335,21 @@ export enum ModelOutputType { Object = 'object', Embedding = 'embedding', Image = 'image', + Rerank = 'rerank', Structured = 'structured', } +export interface 
ModelAttachmentCapability { + kinds: PromptAttachmentKind[]; + sourceKinds?: PromptAttachmentSourceKind[]; + allowRemoteUrls?: boolean; +} + export interface ModelCapability { input: ModelInputType[]; output: ModelOutputType[]; + attachments?: ModelAttachmentCapability; + structuredAttachments?: ModelAttachmentCapability; defaultForOutputType?: boolean; } @@ -248,6 +361,9 @@ export interface CopilotProviderModel { export type ModelConditions = { inputTypes?: ModelInputType[]; + attachmentKinds?: PromptAttachmentKind[]; + attachmentSourceKinds?: PromptAttachmentSourceKind[]; + hasRemoteAttachments?: boolean; modelId?: string; }; diff --git a/packages/backend/server/src/plugins/copilot/providers/utils.ts b/packages/backend/server/src/plugins/copilot/providers/utils.ts index a8b3dabe80..c20959adf9 100644 --- a/packages/backend/server/src/plugins/copilot/providers/utils.ts +++ b/packages/backend/server/src/plugins/copilot/providers/utils.ts @@ -1,34 +1,39 @@ -import { GoogleVertexProviderSettings } from '@ai-sdk/google-vertex'; -import { GoogleVertexAnthropicProviderSettings } from '@ai-sdk/google-vertex/anthropic'; import { Logger } from '@nestjs/common'; -import { - CoreAssistantMessage, - CoreUserMessage, - FilePart, - ImagePart, - TextPart, - TextStreamPart, -} from 'ai'; import { GoogleAuth, GoogleAuthOptions } from 'google-auth-library'; -import z, { ZodType } from 'zod'; +import z from 'zod'; -import { - bufferToArrayBuffer, - fetchBuffer, - OneMinute, - ResponseTooLargeError, - safeFetch, - SsrfBlockedError, -} from '../../../base'; -import { CustomAITools } from '../tools'; -import { PromptMessage, StreamObject } from './types'; +import { OneMinute, safeFetch } from '../../../base'; +import { PromptAttachment, StreamObject } from './types'; -type ChatMessage = CoreUserMessage | CoreAssistantMessage; +export type VertexProviderConfig = { + location?: string; + project?: string; + baseURL?: string; + googleAuthOptions?: GoogleAuthOptions; + fetch?: typeof fetch; 
+}; + +export type VertexAnthropicProviderConfig = VertexProviderConfig; + +type CopilotTextStreamPart = + | { type: 'text-delta'; text: string; id?: string } + | { type: 'reasoning-delta'; text: string; id?: string } + | { + type: 'tool-call'; + toolCallId: string; + toolName: string; + input: Record; + } + | { + type: 'tool-result'; + toolCallId: string; + toolName: string; + input: Record; + output: unknown; + } + | { type: 'error'; error: unknown }; -const ATTACHMENT_MAX_BYTES = 20 * 1024 * 1024; const ATTACH_HEAD_PARAMS = { timeoutMs: OneMinute / 12, maxRedirects: 3 }; - -const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/; const FORMAT_INFER_MAP: Record = { pdf: 'application/pdf', mp3: 'audio/mpeg', @@ -53,9 +58,39 @@ const FORMAT_INFER_MAP: Record = { flv: 'video/flv', }; -async function fetchArrayBuffer(url: string): Promise { - const { buffer } = await fetchBuffer(url, ATTACHMENT_MAX_BYTES); - return bufferToArrayBuffer(buffer); +function toBase64Data(data: string, encoding: 'base64' | 'utf8' = 'base64') { + return encoding === 'base64' + ? data + : Buffer.from(data, 'utf8').toString('base64'); +} + +export function promptAttachmentToUrl( + attachment: PromptAttachment +): string | undefined { + if (typeof attachment === 'string') return attachment; + if ('attachment' in attachment) return attachment.attachment; + switch (attachment.kind) { + case 'url': + return attachment.url; + case 'data': + return `data:${attachment.mimeType};base64,${toBase64Data( + attachment.data, + attachment.encoding + )}`; + case 'bytes': + return `data:${attachment.mimeType};base64,${attachment.data}`; + case 'file_handle': + return; + } +} + +export function promptAttachmentMimeType( + attachment: PromptAttachment, + fallbackMimeType?: string +): string | undefined { + if (typeof attachment === 'string') return fallbackMimeType; + if ('attachment' in attachment) return attachment.mimeType; + return attachment.mimeType ?? 
fallbackMimeType; } export async function inferMimeType(url: string) { @@ -69,344 +104,49 @@ export async function inferMimeType(url: string) { if (ext) { return ext; } - try { - const mimeType = await safeFetch( - url, - { method: 'HEAD' }, - ATTACH_HEAD_PARAMS - ).then(res => res.headers.get('content-type')); - if (mimeType) return mimeType; - } catch { - // ignore and fallback to default - } + } + try { + const mimeType = await safeFetch( + url, + { method: 'HEAD' }, + ATTACH_HEAD_PARAMS + ).then(res => res.headers.get('content-type')); + if (mimeType) return mimeType; + } catch { + // ignore and fallback to default } return 'application/octet-stream'; } -export async function chatToGPTMessage( - messages: PromptMessage[], - // TODO(@darkskygit): move this logic in interface refactoring - withAttachment: boolean = true, - // NOTE: some providers in vercel ai sdk are not able to handle url attachments yet - // so we need to use base64 encoded attachments instead - useBase64Attachment: boolean = false -): Promise<[string | undefined, ChatMessage[], ZodType?]> { - const system = messages[0]?.role === 'system' ? messages.shift() : undefined; - const schema = - system?.params?.schema && system.params.schema instanceof ZodType - ? 
system.params.schema - : undefined; +type CitationIndexedEvent = { + type: 'citation'; + index: number; + url: string; +}; - // filter redundant fields - const msgs: ChatMessage[] = []; - for (let { role, content, attachments, params } of messages.filter( - m => m.role !== 'system' - )) { - content = content.trim(); - role = role as 'user' | 'assistant'; - const mimetype = params?.mimetype; - if (Array.isArray(attachments)) { - const contents: (TextPart | ImagePart | FilePart)[] = []; - if (content.length) { - contents.push({ type: 'text', text: content }); - } +export class CitationFootnoteFormatter { + private readonly citations = new Map(); - if (withAttachment) { - for (let attachment of attachments) { - let mediaType: string; - if (typeof attachment === 'string') { - mediaType = - typeof mimetype === 'string' - ? mimetype - : await inferMimeType(attachment); - } else { - ({ attachment, mimeType: mediaType } = attachment); - } - if (SIMPLE_IMAGE_URL_REGEX.test(attachment)) { - const data = - attachment.startsWith('data:') || useBase64Attachment - ? await fetchArrayBuffer(attachment).catch(error => { - // Avoid leaking internal details for blocked URLs. 
- if ( - error instanceof SsrfBlockedError || - error instanceof ResponseTooLargeError - ) { - throw new Error('Attachment URL is not allowed'); - } - throw error; - }) - : new URL(attachment); - if (mediaType.startsWith('image/')) { - contents.push({ type: 'image', image: data, mediaType }); - } else { - contents.push({ type: 'file' as const, data, mediaType }); - } - } - } - } else if (!content.length) { - // temp fix for pplx - contents.push({ type: 'text', text: '[no content]' }); - } - - msgs.push({ role, content: contents } as ChatMessage); - } else { - msgs.push({ role, content }); + public consume(event: CitationIndexedEvent) { + if (event.type !== 'citation') { + return ''; } - } - - return [system?.content, msgs, schema]; -} - -// pattern types the callback will receive -type Pattern = - | { kind: 'index'; value: number } // [123] - | { kind: 'link'; text: string; url: string } // [text](url) - | { kind: 'wrappedLink'; text: string; url: string }; // ([text](url)) - -type NeedMore = { kind: 'needMore' }; -type Failed = { kind: 'fail'; nextPos: number }; -type Finished = - | { kind: 'ok'; endPos: number; text: string; url: string } - | { kind: 'index'; endPos: number; value: number }; -type ParseStatus = Finished | NeedMore | Failed; - -type PatternCallback = (m: Pattern) => string; - -export class StreamPatternParser { - #buffer = ''; - - constructor(private readonly callback: PatternCallback) {} - - write(chunk: string): string { - this.#buffer += chunk; - const output: string[] = []; - let i = 0; - - while (i < this.#buffer.length) { - const ch = this.#buffer[i]; - - // [[[number]]] or [text](url) or ([text](url)) - if (ch === '[' || (ch === '(' && this.peek(i + 1) === '[')) { - const isWrapped = ch === '('; - const startPos = isWrapped ? 
i + 1 : i; - const res = this.tryParse(startPos); - if (res.kind === 'needMore') break; - const { output: out, nextPos } = this.handlePattern( - res, - isWrapped, - startPos, - i - ); - output.push(out); - i = nextPos; - continue; - } - output.push(ch); - i += 1; - } - - this.#buffer = this.#buffer.slice(i); - return output.join(''); - } - - end(): string { - const rest = this.#buffer; - this.#buffer = ''; - return rest; - } - - // =========== helpers =========== - - private peek(pos: number): string | undefined { - return pos < this.#buffer.length ? this.#buffer[pos] : undefined; - } - - private tryParse(pos: number): ParseStatus { - const nestedRes = this.tryParseNestedIndex(pos); - if (nestedRes) return nestedRes; - return this.tryParseBracketPattern(pos); - } - - private tryParseNestedIndex(pos: number): ParseStatus | null { - if (this.peek(pos + 1) !== '[') return null; - - let i = pos; - let bracketCount = 0; - - while (i < this.#buffer.length && this.#buffer[i] === '[') { - bracketCount++; - i++; - } - - if (bracketCount >= 2) { - if (i >= this.#buffer.length) { - return { kind: 'needMore' }; - } - - let content = ''; - while (i < this.#buffer.length && this.#buffer[i] !== ']') { - content += this.#buffer[i++]; - } - - let rightBracketCount = 0; - while (i < this.#buffer.length && this.#buffer[i] === ']') { - rightBracketCount++; - i++; - } - - if (i >= this.#buffer.length && rightBracketCount < bracketCount) { - return { kind: 'needMore' }; - } - - if ( - rightBracketCount === bracketCount && - content.length > 0 && - this.isNumeric(content) - ) { - if (this.peek(i) === '(') { - return { kind: 'fail', nextPos: i }; - } - return { kind: 'index', endPos: i, value: Number(content) }; - } - } - - return null; - } - - private tryParseBracketPattern(pos: number): ParseStatus { - let i = pos + 1; // skip '[' - if (i >= this.#buffer.length) { - return { kind: 'needMore' }; - } - - let content = ''; - while (i < this.#buffer.length && this.#buffer[i] !== ']') { - 
const nextChar = this.#buffer[i]; - if (nextChar === '[') { - return { kind: 'fail', nextPos: i }; - } - content += nextChar; - i += 1; - } - - if (i >= this.#buffer.length) { - return { kind: 'needMore' }; - } - const after = i + 1; - const afterChar = this.peek(after); - - if (content.length > 0 && this.isNumeric(content) && afterChar !== '(') { - // [number] pattern - return { kind: 'index', endPos: after, value: Number(content) }; - } else if (afterChar !== '(') { - // [text](url) pattern - return { kind: 'fail', nextPos: after }; - } - - i = after + 1; // skip '(' - if (i >= this.#buffer.length) { - return { kind: 'needMore' }; - } - - let url = ''; - while (i < this.#buffer.length && this.#buffer[i] !== ')') { - url += this.#buffer[i++]; - } - if (i >= this.#buffer.length) { - return { kind: 'needMore' }; - } - return { kind: 'ok', endPos: i + 1, text: content, url }; - } - - private isNumeric(str: string): boolean { - return !Number.isNaN(Number(str)) && str.trim() !== ''; - } - - private handlePattern( - pattern: Finished | Failed, - isWrapped: boolean, - start: number, - current: number - ): { output: string; nextPos: number } { - if (pattern.kind === 'fail') { - return { - output: this.#buffer.slice(current, pattern.nextPos), - nextPos: pattern.nextPos, - }; - } - - if (isWrapped) { - const afterLinkPos = pattern.endPos; - if (this.peek(afterLinkPos) !== ')') { - if (afterLinkPos >= this.#buffer.length) { - return { output: '', nextPos: current }; - } - return { output: '(', nextPos: start }; - } - - const out = - pattern.kind === 'index' - ? this.callback({ ...pattern, kind: 'index' }) - : this.callback({ ...pattern, kind: 'wrappedLink' }); - return { output: out, nextPos: afterLinkPos + 1 }; - } else { - const out = - pattern.kind === 'ok' - ? 
this.callback({ ...pattern, kind: 'link' }) - : this.callback({ ...pattern, kind: 'index' }); - return { output: out, nextPos: pattern.endPos }; - } - } -} - -export class CitationParser { - private readonly citations: string[] = []; - - private readonly parser = new StreamPatternParser(p => { - switch (p.kind) { - case 'index': { - if (p.value <= this.citations.length) { - return `[^${p.value}]`; - } - return `[${p.value}]`; - } - case 'wrappedLink': { - const index = this.citations.indexOf(p.url); - if (index === -1) { - this.citations.push(p.url); - return `[^${this.citations.length}]`; - } - return `[^${index + 1}]`; - } - case 'link': { - return `[${p.text}](${p.url})`; - } - } - }); - - public push(citation: string) { - this.citations.push(citation); - } - - public parse(content: string) { - return this.parser.write(content); + this.citations.set(event.index, event.url); + return ''; } public end() { - return this.parser.end() + '\n' + this.getFootnotes(); - } - - private getFootnotes() { - const footnotes = this.citations.map((citation, index) => { - return `[^${index + 1}]: {"type":"url","url":"${encodeURIComponent( - citation - )}"}`; - }); + const footnotes = Array.from(this.citations.entries()) + .sort((a, b) => a[0] - b[0]) + .map( + ([index, citation]) => + `[^${index}]: {"type":"url","url":"${encodeURIComponent(citation)}"}` + ); return footnotes.join('\n'); } } -type ChunkType = TextStreamPart['type']; +type ChunkType = CopilotTextStreamPart['type']; export function toError(error: unknown): Error { if (typeof error === 'string') { @@ -428,6 +168,14 @@ type DocEditFootnote = { intent: string; result: string; }; + +function asRecord(value: unknown): Record | null { + if (value && typeof value === 'object' && !Array.isArray(value)) { + return value as Record; + } + return null; +} + export class TextStreamParser { private readonly logger = new Logger(TextStreamParser.name); private readonly CALLOUT_PREFIX = '\n[!]\n'; @@ -438,7 +186,7 @@ export class 
TextStreamParser { private readonly docEditFootnotes: DocEditFootnote[] = []; - public parse(chunk: TextStreamPart) { + public parse(chunk: CopilotTextStreamPart) { let result = ''; switch (chunk.type) { case 'text-delta': { @@ -487,7 +235,7 @@ export class TextStreamParser { } case 'doc_edit': { this.docEditFootnotes.push({ - intent: chunk.input.instructions, + intent: String(chunk.input.instructions ?? ''), result: '', }); break; @@ -503,14 +251,12 @@ export class TextStreamParser { result = this.addPrefix(result); switch (chunk.toolName) { case 'doc_edit': { - const array = - chunk.output && typeof chunk.output === 'object' - ? chunk.output.result - : undefined; + const output = asRecord(chunk.output); + const array = output?.result; if (Array.isArray(array)) { result += array .map(item => { - return `\n${item.changedContent}\n`; + return `\n${String(asRecord(item)?.changedContent ?? '')}\n`; }) .join(''); this.docEditFootnotes[this.docEditFootnotes.length - 1].result = @@ -527,8 +273,11 @@ export class TextStreamParser { } else if (typeof output === 'string') { result += `\n${output}\n`; } else { + const message = asRecord(output)?.message; this.logger.warn( - `Unexpected result type for doc_semantic_search: ${output?.message || 'Unknown error'}` + `Unexpected result type for doc_semantic_search: ${ + typeof message === 'string' ? message : 'Unknown error' + }` ); } break; @@ -542,9 +291,11 @@ export class TextStreamParser { break; } case 'doc_compose': { - const output = chunk.output; - if (output && typeof output === 'object' && 'title' in output) { - result += `\nDocument "${output.title}" created successfully with ${output.wordCount} words.\n`; + const output = asRecord(chunk.output); + if (output && typeof output.title === 'string') { + result += `\nDocument "${output.title}" created successfully with ${String( + output.wordCount ?? 
0 + )} words.\n`; } break; } @@ -624,7 +375,7 @@ export class TextStreamParser { } export class StreamObjectParser { - public parse(chunk: TextStreamPart) { + public parse(chunk: CopilotTextStreamPart) { switch (chunk.type) { case 'reasoning-delta': { return { type: 'reasoning' as const, textDelta: chunk.text }; @@ -703,21 +454,37 @@ export const VertexModelListSchema = z.object({ ), }); +function normalizeUrl(baseURL?: string) { + if (!baseURL?.trim()) { + return undefined; + } + try { + const url = new URL(baseURL); + const serialized = url.toString(); + if (serialized.endsWith('/')) return serialized.slice(0, -1); + return serialized; + } catch { + return undefined; + } +} + +export function getVertexAnthropicBaseUrl(options: VertexProviderConfig) { + const normalizedBaseUrl = normalizeUrl(options.baseURL); + if (normalizedBaseUrl) return normalizedBaseUrl; + const { location, project } = options; + if (!location || !project) return undefined; + return `https://${location}-aiplatform.googleapis.com/v1/projects/${project}/locations/${location}/publishers/anthropic`; +} + export async function getGoogleAuth( - options: GoogleVertexAnthropicProviderSettings | GoogleVertexProviderSettings, + options: VertexProviderConfig, publisher: 'anthropic' | 'google' ) { function getBaseUrl() { - const { baseURL, location } = options; - if (baseURL?.trim()) { - try { - const url = new URL(baseURL); - if (url.pathname.endsWith('/')) { - url.pathname = url.pathname.slice(0, -1); - } - return url.toString(); - } catch {} - } else if (location) { + const normalizedBaseUrl = normalizeUrl(options.baseURL); + if (normalizedBaseUrl) return normalizedBaseUrl; + const { location } = options; + if (location) { return `https://${location}-aiplatform.googleapis.com/v1beta1/publishers/${publisher}`; } return undefined; @@ -729,7 +496,7 @@ export async function getGoogleAuth( } const auth = new GoogleAuth({ scopes: ['https://www.googleapis.com/auth/cloud-platform'], - 
...(options.googleAuthOptions as GoogleAuthOptions), + ...options.googleAuthOptions, }); const client = await auth.getClient(); const token = await client.getAccessToken(); diff --git a/packages/backend/server/src/plugins/copilot/resolver.ts b/packages/backend/server/src/plugins/copilot/resolver.ts index 2dc036610c..95dbc29add 100644 --- a/packages/backend/server/src/plugins/copilot/resolver.ts +++ b/packages/backend/server/src/plugins/copilot/resolver.ts @@ -4,7 +4,6 @@ import { BadRequestException, NotFoundException } from '@nestjs/common'; import { Args, Field, - Float, ID, InputType, Mutation, @@ -15,7 +14,6 @@ import { ResolveField, Resolver, } from '@nestjs/graphql'; -import { AiPromptRole } from '@prisma/client'; import { GraphQLJSON, SafeIntResolver } from 'graphql-scalars'; import GraphQLUpload from 'graphql-upload/GraphQLUpload.mjs'; @@ -26,6 +24,7 @@ import { CopilotProviderSideError, CopilotSessionNotFound, type FileUpload, + ImageFormatNotSupported, paginate, Paginated, PaginationInput, @@ -41,6 +40,7 @@ import { DocReader } from '../../core/doc'; import { AccessController, DocAction } from '../../core/permission'; import { UserType } from '../../core/user'; import type { ListSessionOptions, UpdateChatSession } from '../../models'; +import { processImage } from '../../native'; import { CopilotCronJobs } from './cron'; import { PromptService } from './prompt/service'; import { CopilotProviderFactory } from './providers/factory'; @@ -50,6 +50,7 @@ import { CopilotStorage } from './storage'; import { type ChatHistory, type ChatMessage, SubmittedMessage } from './types'; export const COPILOT_LOCKER = 'copilot'; +const COPILOT_IMAGE_MAX_EDGE = 1536; // ================== Input Types ================== @@ -313,57 +314,6 @@ class CopilotQuotaType { used!: number; } -registerEnumType(AiPromptRole, { - name: 'CopilotPromptMessageRole', -}); - -@InputType('CopilotPromptConfigInput') -@ObjectType() -class CopilotPromptConfigType { - @Field(() => Float, { 
nullable: true }) - frequencyPenalty!: number | null; - - @Field(() => Float, { nullable: true }) - presencePenalty!: number | null; - - @Field(() => Float, { nullable: true }) - temperature!: number | null; - - @Field(() => Float, { nullable: true }) - topP!: number | null; -} - -@InputType('CopilotPromptMessageInput') -@ObjectType() -class CopilotPromptMessageType { - @Field(() => AiPromptRole) - role!: AiPromptRole; - - @Field(() => String) - content!: string; - - @Field(() => GraphQLJSON, { nullable: true }) - params!: Record | null; -} - -@ObjectType() -class CopilotPromptType { - @Field(() => String) - name!: string; - - @Field(() => String) - model!: string; - - @Field(() => String, { nullable: true }) - action!: string | null; - - @Field(() => CopilotPromptConfigType, { nullable: true }) - config!: CopilotPromptConfigType | null; - - @Field(() => [CopilotPromptMessageType]) - messages!: CopilotPromptMessageType[]; -} - @ObjectType() class CopilotModelType { @Field(() => String) @@ -638,13 +588,8 @@ export class CopilotResolver { ); } - @Mutation(() => String, { - description: 'Create a chat session', - }) - @CallMetric('ai', 'chat_session_create') - async createCopilotSession( - @CurrentUser() user: CurrentUser, - @Args({ name: 'options', type: () => CreateChatSessionInput }) + private async createCopilotSessionInternal( + user: CurrentUser, options: CreateChatSessionInput ): Promise { // permission check based on session type @@ -666,6 +611,42 @@ export class CopilotResolver { }); } + @Mutation(() => String, { + description: 'Create a chat session', + deprecationReason: 'use `createCopilotSessionWithHistory` instead', + }) + @CallMetric('ai', 'chat_session_create') + async createCopilotSession( + @CurrentUser() user: CurrentUser, + @Args({ name: 'options', type: () => CreateChatSessionInput }) + options: CreateChatSessionInput + ): Promise { + return await this.createCopilotSessionInternal(user, options); + } + + @Mutation(() => CopilotHistoriesType, { + 
description: 'Create a chat session and return full session payload', + }) + @CallMetric('ai', 'chat_session_create_with_history') + async createCopilotSessionWithHistory( + @CurrentUser() user: CurrentUser, + @Args({ name: 'options', type: () => CreateChatSessionInput }) + options: CreateChatSessionInput + ): Promise { + const sessionId = await this.createCopilotSessionInternal(user, options); + const session = await this.chatSession.getSessionInfo(sessionId); + if (!session) { + throw new NotFoundException('Session not found'); + } + return { + ...session, + messages: session.messages.map(message => ({ + ...message, + id: message.id, + })) as ChatMessageType[], + }; + } + @Mutation(() => String, { description: 'Update a chat session', }) @@ -799,19 +780,35 @@ export class CopilotResolver { for (const blob of blobs) { const uploaded = await this.storage.handleUpload(user.id, blob); + const detectedMime = + sniffMime(uploaded.buffer, blob.mimetype)?.toLowerCase() || + blob.mimetype; + let attachmentBuffer = uploaded.buffer; + let attachmentMimeType = detectedMime; + + if (detectedMime.startsWith('image/')) { + try { + attachmentBuffer = await processImage( + uploaded.buffer, + COPILOT_IMAGE_MAX_EDGE, + true + ); + attachmentMimeType = 'image/webp'; + } catch { + throw new ImageFormatNotSupported({ format: detectedMime }); + } + } + const filename = createHash('sha256') - .update(uploaded.buffer) + .update(attachmentBuffer) .digest('base64url'); const attachment = await this.storage.put( user.id, workspaceId, filename, - uploaded.buffer + attachmentBuffer ); - attachments.push({ - attachment, - mimeType: sniffMime(uploaded.buffer, blob.mimetype) || blob.mimetype, - }); + attachments.push({ attachment, mimeType: attachmentMimeType }); } } @@ -939,31 +936,10 @@ export class UserCopilotResolver { } } -@InputType() -class CreateCopilotPromptInput { - @Field(() => String) - name!: string; - - @Field(() => String) - model!: string; - - @Field(() => String, { nullable: 
true }) - action!: string | null; - - @Field(() => CopilotPromptConfigType, { nullable: true }) - config!: CopilotPromptConfigType | null; - - @Field(() => [CopilotPromptMessageType]) - messages!: CopilotPromptMessageType[]; -} - @Admin() @Resolver(() => String) export class PromptsManagementResolver { - constructor( - private readonly cron: CopilotCronJobs, - private readonly promptService: PromptService - ) {} + constructor(private readonly cron: CopilotCronJobs) {} @Mutation(() => Boolean, { description: 'Trigger generate missing titles cron job', @@ -980,48 +956,4 @@ export class PromptsManagementResolver { await this.cron.triggerCleanupTrashedDocEmbeddings(); return true; } - - @Query(() => [CopilotPromptType], { - description: 'List all copilot prompts', - }) - async listCopilotPrompts() { - const prompts = await this.promptService.list(); - return prompts.filter( - p => - p.messages.length > 0 && - // ignore internal prompts - !p.name.startsWith('workflow:') && - !p.name.startsWith('debug:') && - !p.name.startsWith('chat:') && - !p.name.startsWith('action:') - ); - } - - @Mutation(() => CopilotPromptType, { - description: 'Create a copilot prompt', - }) - async createCopilotPrompt( - @Args({ type: () => CreateCopilotPromptInput, name: 'input' }) - input: CreateCopilotPromptInput - ) { - await this.promptService.set( - input.name, - input.model, - input.messages, - input.config - ); - return this.promptService.get(input.name); - } - - @Mutation(() => CopilotPromptType, { - description: 'Update a copilot prompt', - }) - async updateCopilotPrompt( - @Args('name') name: string, - @Args('messages', { type: () => [CopilotPromptMessageType] }) - messages: CopilotPromptMessageType[] - ) { - await this.promptService.update(name, { messages, modified: true }); - return this.promptService.get(name); - } } diff --git a/packages/backend/server/src/plugins/copilot/session.ts b/packages/backend/server/src/plugins/copilot/session.ts index 9014620421..38e5c33eee 100644 --- 
a/packages/backend/server/src/plugins/copilot/session.ts +++ b/packages/backend/server/src/plugins/copilot/session.ts @@ -7,6 +7,7 @@ import { AiPromptRole } from '@prisma/client'; import { pick } from 'lodash-es'; import { + Config, CopilotActionTaken, CopilotMessageNotFound, CopilotPromptNotFound, @@ -30,12 +31,15 @@ import { SubscriptionPlan, SubscriptionStatus } from '../payment/types'; import { ChatMessageCache } from './message'; import { ChatPrompt } from './prompt/chat-prompt'; import { PromptService } from './prompt/service'; +import { promptAttachmentHasSource } from './providers/attachments'; import { CopilotProviderFactory } from './providers/factory'; +import { buildProviderRegistry } from './providers/provider-registry'; import { ModelOutputType, type PromptMessage, type PromptParams, } from './providers/types'; +import { promptAttachmentToUrl } from './providers/utils'; import { type ChatHistory, type ChatMessage, @@ -105,10 +109,31 @@ export class ChatSession implements AsyncDisposable { hasPayment: boolean, requestedModelId?: string ): Promise { + const config = this.moduleRef.get(Config, { strict: false }); + const registry = config + ? buildProviderRegistry(config.copilot.providers) + : null; const defaultModel = this.model; - const normalize = (m?: string) => - !!m && this.optionalModels.includes(m) ? m : defaultModel; - const isPro = (m?: string) => !!m && this.proModels.includes(m); + const normalizeModel = (modelId?: string) => { + if (!modelId) return modelId; + const separatorIndex = modelId.indexOf('/'); + if (separatorIndex <= 0) return modelId; + const providerId = modelId.slice(0, separatorIndex); + if (!registry?.profiles.has(providerId)) return modelId; + return modelId.slice(separatorIndex + 1); + }; + const inModelList = (models: string[], modelId?: string) => { + if (!modelId) return false; + return ( + models.includes(modelId) || + models.includes(normalizeModel(modelId) ?? 
'') + ); + }; + const normalize = (m?: string) => { + if (inModelList(this.optionalModels, m)) return m; + return defaultModel; + }; + const isPro = (m?: string) => inModelList(this.proModels, m); // try resolve payment subscription service lazily let paymentEnabled = hasPayment; @@ -132,10 +157,19 @@ export class ChatSession implements AsyncDisposable { } if (paymentEnabled && !isUserAIPro && isPro(requestedModelId)) { + if (!defaultModel) { + throw new CopilotSessionInvalidInput( + 'Model is required for AI subscription fallback' + ); + } return defaultModel; } - return normalize(requestedModelId); + const resolvedModel = normalize(requestedModelId); + if (!resolvedModel) { + throw new CopilotSessionInvalidInput('Model is required'); + } + return resolvedModel; } push(message: ChatMessage) { @@ -240,11 +274,7 @@ export class ChatSession implements AsyncDisposable { lastMessage.attachments || [], ] .flat() - .filter(v => - typeof v === 'string' - ? !!v.trim() - : v && v.attachment.trim() && v.mimeType - ); + .filter(v => promptAttachmentHasSource(v)); //insert all previous user message content before first user message finished.splice(firstUserMessageIndex, 0, ...messages); @@ -434,8 +464,8 @@ export class ChatSessionService { messages: preload.concat(messages).map(m => ({ ...m, attachments: m.attachments - ?.map(a => (typeof a === 'string' ? 
a : a.attachment)) - .filter(a => !!a), + ?.map(a => promptAttachmentToUrl(a)) + .filter((a): a is string => !!a), })), }; } else { diff --git a/packages/backend/server/src/plugins/copilot/tools/blob-read.ts b/packages/backend/server/src/plugins/copilot/tools/blob-read.ts index 1a556010c3..893e440954 100644 --- a/packages/backend/server/src/plugins/copilot/tools/blob-read.ts +++ b/packages/backend/server/src/plugins/copilot/tools/blob-read.ts @@ -1,9 +1,9 @@ import { Logger } from '@nestjs/common'; -import { tool } from 'ai'; import { z } from 'zod'; import { AccessController } from '../../../core/permission'; import { toolError } from './error'; +import { defineTool } from './tool'; import type { ContextSession, CopilotChatOptions } from './types'; const logger = new Logger('ContextBlobReadTool'); @@ -32,16 +32,22 @@ export const buildBlobContentGetter = ( return; } + const contextFile = context.files.find( + file => file.blobId === blobId || file.id === blobId + ); + const canonicalBlobId = contextFile?.blobId ?? blobId; + const targetFileId = contextFile?.id; const [file, blob] = await Promise.all([ - context?.getFileContent(blobId, chunk), - context?.getBlobContent(blobId, chunk), + targetFileId ? context.getFileContent(targetFileId, chunk) : undefined, + context.getBlobContent(canonicalBlobId, chunk), ]); const content = file?.trim() || blob?.trim(); - if (!content) { - return; - } + if (!content) return; + const info = contextFile + ? 
{ fileName: contextFile.name, fileType: contextFile.mimeType } + : {}; - return { blobId, chunk, content }; + return { blobId: canonicalBlobId, chunk, content, ...info }; }; return getBlobContent; }; @@ -52,7 +58,7 @@ export const createBlobReadTool = ( chunk?: number ) => Promise ) => { - return tool({ + return defineTool({ description: 'Return the content and basic metadata of a single attachment identified by blobId; more inclined to use search tools rather than this tool.', inputSchema: z.object({ diff --git a/packages/backend/server/src/plugins/copilot/tools/code-artifact.ts b/packages/backend/server/src/plugins/copilot/tools/code-artifact.ts index d46b88f448..639567a84e 100644 --- a/packages/backend/server/src/plugins/copilot/tools/code-artifact.ts +++ b/packages/backend/server/src/plugins/copilot/tools/code-artifact.ts @@ -1,8 +1,8 @@ import { Logger } from '@nestjs/common'; -import { tool } from 'ai'; import { z } from 'zod'; import { toolError } from './error'; +import { defineTool } from './tool'; import type { CopilotProviderFactory, PromptService } from './types'; const logger = new Logger('CodeArtifactTool'); @@ -16,7 +16,7 @@ export const createCodeArtifactTool = ( promptService: PromptService, factory: CopilotProviderFactory ) => { - return tool({ + return defineTool({ description: 'Generate a single-file HTML snippet (with inline