mirror of
https://github.com/toeverything/AFFiNE.git
synced 2026-05-08 22:07:32 +08:00
feat(server): refactor for byok (#14911)
This commit is contained in:
@@ -995,6 +995,26 @@
|
||||
"description": "Whether to enable the copilot plugin. <br> Document: <a href=\"https://docs.affine.pro/self-host-affine/administer/ai\" target=\"_blank\">https://docs.affine.pro/self-host-affine/administer/ai</a>\n@default false",
|
||||
"default": false
|
||||
},
|
||||
"byok.enabled": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to enable workspace BYOK.\n@default true",
|
||||
"default": true
|
||||
},
|
||||
"byok.allowedProviders": {
|
||||
"type": "array",
|
||||
"description": "The allowlist for workspace BYOK providers.\n@default [\"openai\",\"anthropic\",\"gemini\",\"fal\"]",
|
||||
"default": [
|
||||
"openai",
|
||||
"anthropic",
|
||||
"gemini",
|
||||
"fal"
|
||||
]
|
||||
},
|
||||
"byok.allowCustomEndpoint": {
|
||||
"type": "boolean",
|
||||
"description": "Whether workspace BYOK custom endpoints are accepted.\n@default false",
|
||||
"default": false
|
||||
},
|
||||
"providers.profiles": {
|
||||
"type": "array",
|
||||
"description": "The profile list for copilot providers.\n@default []",
|
||||
@@ -1071,13 +1091,6 @@
|
||||
},
|
||||
"default": {}
|
||||
},
|
||||
"providers.perplexity": {
|
||||
"type": "object",
|
||||
"description": "The config for the perplexity provider.\n@default {\"apiKey\":\"\"}",
|
||||
"default": {
|
||||
"apiKey": ""
|
||||
}
|
||||
},
|
||||
"providers.anthropic": {
|
||||
"type": "object",
|
||||
"description": "The config for the anthropic provider.\n@default {\"apiKey\":\"\",\"baseURL\":\"https://api.anthropic.com/v1\"}",
|
||||
@@ -1121,11 +1134,6 @@
|
||||
},
|
||||
"default": {}
|
||||
},
|
||||
"providers.morph": {
|
||||
"type": "object",
|
||||
"description": "The config for the morph provider.\n@default {}",
|
||||
"default": {}
|
||||
},
|
||||
"unsplash": {
|
||||
"type": "object",
|
||||
"description": "The config for the unsplash key.\n@default {\"key\":\"\"}",
|
||||
|
||||
6
packages/backend/native/index.d.ts
vendored
6
packages/backend/native/index.d.ts
vendored
@@ -364,7 +364,7 @@ export interface ModelConditionsContract {
|
||||
}
|
||||
|
||||
export interface ModelRegistryMatchRequest {
|
||||
backendKind: 'openai_chat' | 'openai_responses' | 'anthropic' | 'cloudflare_workers_ai' | 'gemini_api' | 'gemini_vertex' | 'fal' | 'perplexity' | 'anthropic_vertex' | 'morph'
|
||||
backendKind: 'openai_chat' | 'openai_responses' | 'anthropic' | 'cloudflare_workers_ai' | 'gemini_api' | 'gemini_vertex' | 'fal' | 'anthropic_vertex'
|
||||
cond: ModelConditionsContract
|
||||
}
|
||||
|
||||
@@ -373,7 +373,7 @@ export interface ModelRegistryMatchResponse {
|
||||
}
|
||||
|
||||
export interface ModelRegistryResolveRequest {
|
||||
backendKind?: 'openai_chat' | 'openai_responses' | 'anthropic' | 'cloudflare_workers_ai' | 'gemini_api' | 'gemini_vertex' | 'fal' | 'perplexity' | 'anthropic_vertex' | 'morph'
|
||||
backendKind?: 'openai_chat' | 'openai_responses' | 'anthropic' | 'cloudflare_workers_ai' | 'gemini_api' | 'gemini_vertex' | 'fal' | 'anthropic_vertex'
|
||||
modelId: string
|
||||
}
|
||||
|
||||
@@ -388,7 +388,7 @@ export interface ModelRegistryRouteContract {
|
||||
}
|
||||
|
||||
export interface ModelRegistryVariantContract {
|
||||
backendKind: 'openai_chat' | 'openai_responses' | 'anthropic' | 'cloudflare_workers_ai' | 'gemini_api' | 'gemini_vertex' | 'fal' | 'perplexity' | 'anthropic_vertex' | 'morph'
|
||||
backendKind: 'openai_chat' | 'openai_responses' | 'anthropic' | 'cloudflare_workers_ai' | 'gemini_api' | 'gemini_vertex' | 'fal' | 'anthropic_vertex'
|
||||
canonicalKey: string
|
||||
rawModelId: string
|
||||
displayName?: string
|
||||
|
||||
@@ -686,17 +686,6 @@
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Apply Updates",
|
||||
"action": "Apply Updates",
|
||||
"model": "claude-sonnet-4-5@20250929",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"template": "\nYou are a Markdown document update engine.\n\nYou will be given:\n\n1. content: The original Markdown document\n - The content is structured into blocks.\n - Each block starts with a comment like <!-- block_id=... flavour=... --> and contains the block's content.\n - The content is {{content}}\n\n2. op: A description of the edit intention\n - This describes the semantic meaning of the edit, such as \"Bold the first paragraph\".\n - The op is {{op}}\n\n3. updates: A Markdown snippet\n - The updates is {{updates}}\n - This represents the block-level changes to apply to the original Markdown.\n - The update may:\n - **Replace** an existing block (same block_id, new content)\n - **Delete** block(s) using <!-- delete block BLOCK_ID -->\n - **Insert** new block(s) with a new unique block_id\n - When performing deletions, the update will include **surrounding context blocks** (or use <!-- existing blocks -->) to help you determine where and what to delete.\n\nYour task:\n- Apply the update in <updates> to the document in <code>, following the intent described in <op>.\n- Preserve all block_id and flavour comments.\n- Maintain the original block order unless the update clearly appends new blocks.\n- Do not remove or alter unrelated blocks.\n- Output only the fully updated Markdown content. Do not wrap the content in ```markdown.\n\n---\n\n✍️ Examples\n\n✅ Replacement (modifying an existing block)\n\n<code>\n<!-- block_id=101 flavour=paragraph -->\n## Introduction\n\n<!-- block_id=102 flavour=paragraph -->\nThis document provides an overview of the system architecture and its components.\n</code>\n\n<op>\nMake the introduction more formal.\n</op>\n\n<updates>\n<!-- block_id=102 flavour=paragraph -->\nThis document outlines the architectural design and individual components of the system in detail.\n</updates>\n\nExpected Output:\n<!-- block_id=101 flavour=paragraph -->\n## Introduction\n\n<!-- block_id=102 flavour=paragraph -->\nThis document outlines the architectural design and individual components of the system in detail.\n\n---\n\n➕ Insertion (adding new content)\n\n<code>\n<!-- block_id=201 flavour=paragraph -->\n# Project Summary\n\n<!-- block_id=202 flavour=paragraph -->\nThis project aims to build a collaborative text editing tool.\n</code>\n\n<op>\nAdd a disclaimer section at the end.\n</op>\n\n<updates>\n<!-- block_id=new-301 flavour=paragraph -->\n## Disclaimer\n\n<!-- block_id=new-302 flavour=paragraph -->\nThis document is subject to change. Do not distribute externally.\n</updates>\n\nExpected Output:\n<!-- block_id=201 flavour=paragraph -->\n# Project Summary\n\n<!-- block_id=202 flavour=paragraph -->\nThis project aims to build a collaborative text editing tool.\n\n<!-- block_id=new-301 flavour=paragraph -->\n## Disclaimer\n\n<!-- block_id=new-302 flavour=paragraph -->\nThis document is subject to change. Do not distribute externally.\n\n---\n\n❌ Deletion (removing blocks)\n\n<code>\n<!-- block_id=401 flavour=paragraph -->\n## Author\n\n<!-- block_id=402 flavour=paragraph -->\nWritten by the AI team at OpenResearch.\n\n<!-- block_id=403 flavour=paragraph -->\n## Experimental Section\n\n<!-- block_id=404 flavour=paragraph -->\nThe following section is still under development and may change without notice.\n\n<!-- block_id=405 flavour=paragraph -->\n## License\n\n<!-- block_id=406 flavour=paragraph -->\nThis document is licensed under CC BY-NC 4.0.\n</code>\n\n<op>\nRemove the experimental section.\n</op>\n\n<updates>\n<!-- delete block_id=403 -->\n<!-- delete block_id=404 -->\n</updates>\n\nExpected Output:\n<!-- block_id=401 flavour=paragraph -->\n## Author\n\n<!-- block_id=402 flavour=paragraph -->\nWritten by the AI team at OpenResearch.\n\n<!-- block_id=405 flavour=paragraph -->\n## License\n\n<!-- block_id=406 flavour=paragraph -->\nThis document is licensed under CC BY-NC 4.0.\n\n---\n\nNow apply the `updates` to the `content`, following the intent in `op`, and return the updated Markdown.\n"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Code Artifact",
|
||||
"model": "claude-sonnet-4-5@20250929",
|
||||
|
||||
@@ -319,7 +319,7 @@ pub struct RequestedModelMatchResponse {
|
||||
pub struct ModelRegistryResolveRequest {
|
||||
#[napi(
|
||||
ts_type = "'openai_chat' | 'openai_responses' | 'anthropic' | 'cloudflare_workers_ai' | 'gemini_api' | \
|
||||
'gemini_vertex' | 'fal' | 'perplexity' | 'anthropic_vertex' | 'morph'"
|
||||
'gemini_vertex' | 'fal' | 'anthropic_vertex'"
|
||||
)]
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub backend_kind: Option<String>,
|
||||
@@ -333,7 +333,7 @@ pub struct ModelRegistryResolveRequest {
|
||||
pub struct ModelRegistryMatchRequest {
|
||||
#[napi(
|
||||
ts_type = "'openai_chat' | 'openai_responses' | 'anthropic' | 'cloudflare_workers_ai' | 'gemini_api' | \
|
||||
'gemini_vertex' | 'fal' | 'perplexity' | 'anthropic_vertex' | 'morph'"
|
||||
'gemini_vertex' | 'fal' | 'anthropic_vertex'"
|
||||
)]
|
||||
pub backend_kind: String,
|
||||
pub cond: ModelConditionsContract,
|
||||
@@ -346,7 +346,7 @@ pub struct ModelRegistryMatchRequest {
|
||||
pub struct ModelRegistryVariantContract {
|
||||
#[napi(
|
||||
ts_type = "'openai_chat' | 'openai_responses' | 'anthropic' | 'cloudflare_workers_ai' | 'gemini_api' | \
|
||||
'gemini_vertex' | 'fal' | 'perplexity' | 'anthropic_vertex' | 'morph'"
|
||||
'gemini_vertex' | 'fal' | 'anthropic_vertex'"
|
||||
)]
|
||||
pub backend_kind: String,
|
||||
pub canonical_key: String,
|
||||
|
||||
@@ -7,7 +7,7 @@ pub(crate) use error::{
|
||||
STREAM_ABORTED_REASON, STREAM_CALLBACK_DISPATCH_FAILED_REASON, STREAM_END_MARKER, callback_dispatch_failed_reason,
|
||||
invalid_arg,
|
||||
};
|
||||
pub(crate) use stream::emit_error_event;
|
||||
pub(crate) use stream::{emit_error_event, emit_provider_selected_event};
|
||||
pub use stream::{
|
||||
llm_dispatch_prepared_stream, llm_dispatch_tool_loop_stream, llm_dispatch_tool_loop_stream_prepared,
|
||||
llm_dispatch_tool_loop_stream_routed,
|
||||
|
||||
@@ -106,14 +106,18 @@ fn spawn_prepared_stream(
|
||||
if reason.starts_with(STREAM_CALLBACK_DISPATCH_FAILED_REASON)
|
||||
);
|
||||
|
||||
if let Err(error) = result
|
||||
if let Err(error) = &result
|
||||
&& !aborted_in_worker.load(Ordering::Relaxed)
|
||||
&& !callback_dispatch_failed
|
||||
&& !is_abort_error(&error)
|
||||
&& !is_abort_error(error)
|
||||
{
|
||||
emit_error_event(&callback, error.to_string(), "dispatch_error");
|
||||
}
|
||||
|
||||
if let Ok(provider_id) = result {
|
||||
emit_provider_selected_event(&callback, provider_id);
|
||||
}
|
||||
|
||||
if !callback_dispatch_failed {
|
||||
let _ = callback.call(
|
||||
Ok(STREAM_END_MARKER.to_string()),
|
||||
@@ -129,7 +133,7 @@ fn dispatch_prepared_stream_with_fallback(
|
||||
routes: &[PreparedDispatchRoute],
|
||||
callback: &ThreadsafeFunction<String, ()>,
|
||||
aborted: &AtomicBool,
|
||||
) -> std::result::Result<(), BackendError> {
|
||||
) -> std::result::Result<String, BackendError> {
|
||||
dispatch_prepared_stream_with_fallback_using_client(&DefaultHttpClient::default(), routes, aborted, |event| {
|
||||
emit_stream_event(callback, event)
|
||||
})
|
||||
@@ -140,7 +144,7 @@ fn dispatch_prepared_stream_with_fallback_using_client<F>(
|
||||
routes: &[PreparedDispatchRoute],
|
||||
aborted: &AtomicBool,
|
||||
mut emit_event: F,
|
||||
) -> std::result::Result<(), BackendError>
|
||||
) -> std::result::Result<String, BackendError>
|
||||
where
|
||||
F: FnMut(&StreamEvent) -> Status,
|
||||
{
|
||||
@@ -154,7 +158,7 @@ where
|
||||
.collect::<std::result::Result<Vec<_>, BackendError>>()?;
|
||||
let mut callback_dispatch_failed = false;
|
||||
|
||||
dispatch_prepared_stream_with_pipeline(
|
||||
let provider_id = dispatch_prepared_stream_with_pipeline(
|
||||
client,
|
||||
&mut adapter_routes,
|
||||
|| aborted.load(Ordering::Relaxed),
|
||||
@@ -174,7 +178,7 @@ where
|
||||
"{STREAM_CALLBACK_DISPATCH_FAILED_REASON}:unknown"
|
||||
)))
|
||||
} else {
|
||||
Ok(())
|
||||
Ok(provider_id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -195,6 +199,16 @@ pub(crate) fn emit_error_event(callback: &ThreadsafeFunction<String, ()>, messag
|
||||
let _ = callback.call(Ok(error_event), ThreadsafeFunctionCallMode::NonBlocking);
|
||||
}
|
||||
|
||||
pub(crate) fn emit_provider_selected_event(callback: &ThreadsafeFunction<String, ()>, provider_id: String) {
|
||||
let event = serde_json::json!({
|
||||
"type": "provider_selected",
|
||||
"provider_id": provider_id,
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let _ = callback.call(Ok(event), ThreadsafeFunctionCallMode::NonBlocking);
|
||||
}
|
||||
|
||||
fn emit_stream_event(callback: &ThreadsafeFunction<String, ()>, event: &StreamEvent) -> Status {
|
||||
let value = serde_json::to_string(event).unwrap_or_else(|error| {
|
||||
serde_json::json!({
|
||||
|
||||
@@ -14,7 +14,10 @@ use napi::{
|
||||
threadsafe_function::{ThreadsafeFunction, ThreadsafeFunctionCallMode},
|
||||
};
|
||||
|
||||
use super::callback::{NapiEventSink, NapiToolExecutor, emit_tool_loop_event};
|
||||
use super::{
|
||||
super::emit_provider_selected_event,
|
||||
callback::{NapiEventSink, NapiToolExecutor, emit_tool_loop_event},
|
||||
};
|
||||
use crate::llm::{
|
||||
LlmDispatchPayload, LlmMiddlewarePayload, LlmStreamHandle, STREAM_ABORTED_REASON,
|
||||
STREAM_CALLBACK_DISPATCH_FAILED_REASON, STREAM_END_MARKER, StreamPipeline, apply_request_middlewares,
|
||||
@@ -39,11 +42,13 @@ fn dispatch_prepared_round_with_fallback(
|
||||
})
|
||||
.collect::<std::result::Result<Vec<_>, BackendError>>()?;
|
||||
|
||||
run_prepared_stream_round_with_fallback(
|
||||
let mut selected_provider_id: Option<String> = None;
|
||||
let outcome = run_prepared_stream_round_with_fallback(
|
||||
&mut pipelines,
|
||||
|on_event| {
|
||||
let (selected_index, _) =
|
||||
let (selected_index, provider_id) =
|
||||
dispatch_prepared_stream_with_fallback_index(&DefaultHttpClient::default(), &adapter_routes, on_event)?;
|
||||
selected_provider_id = Some(provider_id);
|
||||
Ok(selected_index)
|
||||
},
|
||||
|| aborted.load(Ordering::Relaxed),
|
||||
@@ -53,7 +58,11 @@ fn dispatch_prepared_round_with_fallback(
|
||||
emitted.store(true, Ordering::Relaxed);
|
||||
emit_tool_loop_event(callback, loop_event)
|
||||
},
|
||||
)
|
||||
)?;
|
||||
if let Some(provider_id) = selected_provider_id {
|
||||
emit_provider_selected_event(callback, provider_id);
|
||||
}
|
||||
Ok(outcome)
|
||||
}
|
||||
|
||||
fn prepare_tool_loop_route(
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
-- CreateTable
|
||||
CREATE TABLE "ai_workspace_byok_configs" (
|
||||
"id" VARCHAR NOT NULL,
|
||||
"workspace_id" VARCHAR NOT NULL,
|
||||
"provider" VARCHAR NOT NULL,
|
||||
"name" VARCHAR NOT NULL,
|
||||
"description" VARCHAR,
|
||||
"encrypted_api_key" TEXT NOT NULL,
|
||||
"endpoint" TEXT,
|
||||
"sort_order" INTEGER NOT NULL DEFAULT 0,
|
||||
"enabled" BOOLEAN NOT NULL DEFAULT true,
|
||||
"disabled_reason" VARCHAR,
|
||||
"last_validated_at" TIMESTAMPTZ(3),
|
||||
"last_validation_error" TEXT,
|
||||
"last_used_at" TIMESTAMPTZ(3),
|
||||
"last_error_at" TIMESTAMPTZ(3),
|
||||
"last_error" TEXT,
|
||||
"created_by" VARCHAR,
|
||||
"updated_by" VARCHAR,
|
||||
"created_at" TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updated_at" TIMESTAMPTZ(3) NOT NULL,
|
||||
|
||||
CONSTRAINT "ai_workspace_byok_configs_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "ai_usage_events" (
|
||||
"id" VARCHAR NOT NULL,
|
||||
"workspace_id" VARCHAR NOT NULL,
|
||||
"user_id" VARCHAR,
|
||||
"provider" VARCHAR NOT NULL,
|
||||
"provider_source" VARCHAR NOT NULL,
|
||||
"feature_kind" VARCHAR NOT NULL,
|
||||
"model" VARCHAR,
|
||||
"session_id" VARCHAR,
|
||||
"task_id" VARCHAR,
|
||||
"action_id" VARCHAR,
|
||||
"billing_unit_id" VARCHAR,
|
||||
"prompt_tokens" INTEGER NOT NULL DEFAULT 0,
|
||||
"completion_tokens" INTEGER NOT NULL DEFAULT 0,
|
||||
"total_tokens" INTEGER NOT NULL DEFAULT 0,
|
||||
"cached_tokens" INTEGER NOT NULL DEFAULT 0,
|
||||
"created_at" TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
|
||||
CONSTRAINT "ai_usage_events_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "ai_workspace_byok_configs_workspace_id_idx" ON "ai_workspace_byok_configs"("workspace_id");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "ai_workspace_byok_configs_workspace_id_provider_enabled_sor_idx" ON "ai_workspace_byok_configs"("workspace_id", "provider", "enabled", "sort_order");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "ai_workspace_byok_configs_workspace_id_provider_name_key" ON "ai_workspace_byok_configs"("workspace_id", "provider", "name");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "ai_usage_events_workspace_id_created_at_idx" ON "ai_usage_events"("workspace_id", "created_at");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "ai_usage_events_workspace_id_provider_source_created_at_idx" ON "ai_usage_events"("workspace_id", "provider_source", "created_at");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "ai_usage_events_feature_kind_created_at_idx" ON "ai_usage_events"("feature_kind", "created_at");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "ai_usage_events_quota_exempt_idx" ON "ai_usage_events"("user_id", "provider_source", "feature_kind", "billing_unit_id");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "ai_workspace_byok_configs" ADD CONSTRAINT "ai_workspace_byok_configs_workspace_id_fkey" FOREIGN KEY ("workspace_id") REFERENCES "workspaces"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "ai_usage_events" ADD CONSTRAINT "ai_usage_events_workspace_id_fkey" FOREIGN KEY ("workspace_id") REFERENCES "workspaces"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
@@ -147,6 +147,8 @@ model Workspace {
|
||||
blobs Blob[]
|
||||
ignoredDocs AiWorkspaceIgnoredDocs[]
|
||||
embedFiles AiWorkspaceFiles[]
|
||||
byokConfigs AiWorkspaceByokConfig[]
|
||||
aiUsageEvents AiUsageEvent[]
|
||||
comments Comment[]
|
||||
commentAttachments CommentAttachment[]
|
||||
workspaceCalendars WorkspaceCalendar[]
|
||||
@@ -732,6 +734,62 @@ model AiWorkspaceBlobEmbedding {
|
||||
@@map("ai_workspace_blob_embeddings")
|
||||
}
|
||||
|
||||
model AiWorkspaceByokConfig {
|
||||
id String @id @default(uuid()) @db.VarChar
|
||||
workspaceId String @map("workspace_id") @db.VarChar
|
||||
provider String @db.VarChar
|
||||
name String @db.VarChar
|
||||
description String? @db.VarChar
|
||||
encryptedApiKey String @map("encrypted_api_key") @db.Text
|
||||
endpoint String? @db.Text
|
||||
sortOrder Int @default(0) @map("sort_order")
|
||||
enabled Boolean @default(true)
|
||||
disabledReason String? @map("disabled_reason") @db.VarChar
|
||||
lastValidatedAt DateTime? @map("last_validated_at") @db.Timestamptz(3)
|
||||
lastValidationError String? @map("last_validation_error") @db.Text
|
||||
lastUsedAt DateTime? @map("last_used_at") @db.Timestamptz(3)
|
||||
lastErrorAt DateTime? @map("last_error_at") @db.Timestamptz(3)
|
||||
lastError String? @map("last_error") @db.Text
|
||||
createdBy String? @map("created_by") @db.VarChar
|
||||
updatedBy String? @map("updated_by") @db.VarChar
|
||||
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3)
|
||||
updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(3)
|
||||
|
||||
workspace Workspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade)
|
||||
|
||||
@@unique([workspaceId, provider, name])
|
||||
@@index([workspaceId])
|
||||
@@index([workspaceId, provider, enabled, sortOrder])
|
||||
@@map("ai_workspace_byok_configs")
|
||||
}
|
||||
|
||||
model AiUsageEvent {
|
||||
id String @id @default(uuid()) @db.VarChar
|
||||
workspaceId String @map("workspace_id") @db.VarChar
|
||||
userId String? @map("user_id") @db.VarChar
|
||||
provider String @db.VarChar
|
||||
providerSource String @map("provider_source") @db.VarChar
|
||||
featureKind String @map("feature_kind") @db.VarChar
|
||||
model String? @db.VarChar
|
||||
sessionId String? @map("session_id") @db.VarChar
|
||||
taskId String? @map("task_id") @db.VarChar
|
||||
actionId String? @map("action_id") @db.VarChar
|
||||
billingUnitId String? @map("billing_unit_id") @db.VarChar
|
||||
promptTokens Int @default(0) @map("prompt_tokens")
|
||||
completionTokens Int @default(0) @map("completion_tokens")
|
||||
totalTokens Int @default(0) @map("total_tokens")
|
||||
cachedTokens Int @default(0) @map("cached_tokens")
|
||||
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3)
|
||||
|
||||
workspace Workspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade)
|
||||
|
||||
@@index([workspaceId, createdAt])
|
||||
@@index([workspaceId, providerSource, createdAt])
|
||||
@@index([featureKind, createdAt])
|
||||
@@index([userId, providerSource, featureKind, billingUnitId], map: "ai_usage_events_quota_exempt_idx")
|
||||
@@map("ai_usage_events")
|
||||
}
|
||||
|
||||
enum AiJobStatus {
|
||||
pending
|
||||
running
|
||||
|
||||
@@ -526,17 +526,6 @@ Generated by [AVA](https://avajs.dev).
|
||||
remoteAttachmentRequests: [],
|
||||
}
|
||||
|
||||
## PerplexityProvider should ignore attachments during text model matching
|
||||
|
||||
> Snapshot 1
|
||||
|
||||
[
|
||||
{
|
||||
text: 'summarize this',
|
||||
type: 'text',
|
||||
},
|
||||
]
|
||||
|
||||
## GeminiVertexProvider should prefetch bearer token for native config
|
||||
|
||||
> Snapshot 1
|
||||
|
||||
1201
packages/backend/server/src/__tests__/copilot/byok.spec.ts
Normal file
1201
packages/backend/server/src/__tests__/copilot/byok.spec.ts
Normal file
File diff suppressed because it is too large
Load Diff
@@ -771,7 +771,7 @@ function actionRunRecord(
|
||||
};
|
||||
}
|
||||
|
||||
function installActionSessionMock(
|
||||
async function installActionSessionMock(
|
||||
t: ExecutionContext<Tester>,
|
||||
{
|
||||
actionId,
|
||||
@@ -786,8 +786,12 @@ function installActionSessionMock(
|
||||
const { models, session } = t.context;
|
||||
const sandbox = Sinon.createSandbox();
|
||||
const sessionId = `copilot-provider-action-${actionId}-${randomUUID()}`;
|
||||
const userId = `copilot-provider-user-${randomUUID()}`;
|
||||
const workspaceId = `copilot-provider-action-${actionId}`;
|
||||
const user = await models.user.create({
|
||||
email: `copilot-provider-user-${randomUUID()}@affine.test`,
|
||||
});
|
||||
const userId = user.id;
|
||||
const workspace = await models.workspace.create(userId);
|
||||
const workspaceId = workspace.id;
|
||||
const docId = `copilot-provider-action-${actionId}-doc`;
|
||||
const savedTurns: Array<{ role: string }> = [];
|
||||
const userTurn = {
|
||||
@@ -904,7 +908,11 @@ for (const { actionId, content, verifier } of actionRecipeCases) {
|
||||
}
|
||||
|
||||
const { sandbox, sessionId, userId, savedTurns } =
|
||||
installActionSessionMock(t, { actionId, actionPrompt, content });
|
||||
await installActionSessionMock(t, {
|
||||
actionId,
|
||||
actionPrompt,
|
||||
content,
|
||||
});
|
||||
|
||||
let result = '';
|
||||
try {
|
||||
@@ -976,8 +984,10 @@ for (const testCase of TRANSCRIPT_AUDIO_CASES) {
|
||||
runIfCopilotConfigured,
|
||||
async t => {
|
||||
const { models, transcript } = t.context;
|
||||
const userId = `copilot-provider-transcript-user-${randomUUID()}`;
|
||||
const workspaceId = `copilot-provider-transcript-workspace-${randomUUID()}`;
|
||||
const user = await models.user.create({
|
||||
email: `copilot-provider-transcript-${randomUUID()}@affine.pro`,
|
||||
});
|
||||
const workspace = await models.workspace.create(user.id);
|
||||
const blobId = `copilot-provider-transcript-blob-${randomUUID()}`;
|
||||
const payload = TranscriptPayloadSchema.parse({
|
||||
sourceAudio: { blobId, mimeType: testCase.mimeType },
|
||||
@@ -990,8 +1000,8 @@ for (const testCase of TRANSCRIPT_AUDIO_CASES) {
|
||||
],
|
||||
});
|
||||
const task = await models.copilotTranscriptTask.create({
|
||||
userId,
|
||||
workspaceId,
|
||||
userId: user.id,
|
||||
workspaceId: workspace.id,
|
||||
blobId,
|
||||
strategy: 'gemini',
|
||||
recipeId: 'transcript.audio.gemini',
|
||||
|
||||
@@ -139,9 +139,6 @@ test.before(async t => {
|
||||
fal: {
|
||||
apiKey: process.env.COPILOT_FAL_API_KEY ?? '1',
|
||||
},
|
||||
perplexity: {
|
||||
apiKey: process.env.COPILOT_PERPLEXITY_API_KEY ?? '1',
|
||||
},
|
||||
anthropic: {
|
||||
apiKey: process.env.COPILOT_ANTHROPIC_API_KEY ?? '1',
|
||||
},
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
import test from 'ava';
|
||||
import Sinon from 'sinon';
|
||||
|
||||
import type { Models } from '../../models';
|
||||
import { type Models } from '../../models';
|
||||
import { CopilotAccessPolicy } from '../../plugins/copilot/access';
|
||||
import type { ByokFeatureKind } from '../../plugins/copilot/byok/types';
|
||||
import { HistoryAttachmentUrlProjector } from '../../plugins/copilot/compat/history-attachment-url-projector';
|
||||
import { CompatHistoryProjector } from '../../plugins/copilot/compat/history-projector';
|
||||
import { HistoryPromptPreloadProjector } from '../../plugins/copilot/compat/history-prompt-preload-projector';
|
||||
import { HistoryVisibilityPolicy } from '../../plugins/copilot/compat/history-visibility-policy';
|
||||
import { ConversationPolicy } from '../../plugins/copilot/conversation/policy';
|
||||
import type { Turn } from '../../plugins/copilot/core';
|
||||
import { CopilotEmbeddingClientService } from '../../plugins/copilot/embedding/client';
|
||||
import { CopilotProviderType } from '../../plugins/copilot/providers/types';
|
||||
import {
|
||||
@@ -29,9 +33,11 @@ import {
|
||||
AttachmentMaterializer,
|
||||
resolveAttachmentFetchUrl,
|
||||
} from '../../plugins/copilot/runtime/hosts/attachment-materializer';
|
||||
import { ConversationHost } from '../../plugins/copilot/runtime/hosts/conversation-host';
|
||||
import { ImageResultHost } from '../../plugins/copilot/runtime/hosts/image-result-host';
|
||||
import { ResponsePostprocessor } from '../../plugins/copilot/runtime/hosts/response-postprocessor';
|
||||
import { TurnPersistence } from '../../plugins/copilot/runtime/hosts/turn-persistence';
|
||||
import { ToolRuntime } from '../../plugins/copilot/runtime/tool-runtime';
|
||||
|
||||
function stubTurnPersistence(
|
||||
persistProjectedResult: Sinon.SinonStub = Sinon.stub().resolves(null)
|
||||
@@ -41,6 +47,367 @@ function stubTurnPersistence(
|
||||
} as unknown as TurnPersistence;
|
||||
}
|
||||
|
||||
function stubConversationSession(latestUserTurn?: unknown) {
|
||||
return {
|
||||
config: {
|
||||
sessionId: 'session-1',
|
||||
userId: 'user-1',
|
||||
workspaceId: 'workspace-1',
|
||||
},
|
||||
model: 'gpt-4o-mini',
|
||||
stashTurns: latestUserTurn ? [latestUserTurn] : [],
|
||||
latestUserTurn,
|
||||
revertLatestMessage: Sinon.stub(),
|
||||
};
|
||||
}
|
||||
|
||||
test('ConversationPolicy should treat zero quota limit as exhausted', async t => {
|
||||
const policy = new ConversationPolicy(
|
||||
{
|
||||
userFeature: { has: Sinon.stub().resolves(false) },
|
||||
copilotSession: { countUserMessages: Sinon.stub().resolves(0) },
|
||||
} as any,
|
||||
{
|
||||
getUserQuota: Sinon.stub().resolves({ copilotActionLimit: 0 }),
|
||||
} as any
|
||||
);
|
||||
|
||||
t.false(await policy.hasQuota('user-1'));
|
||||
await t.throwsAsync(policy.checkQuota('user-1'));
|
||||
});
|
||||
|
||||
type TurnRouteAccessCase = {
|
||||
name: string;
|
||||
profiles: Array<{ id: string }>;
|
||||
featureKind?: 'embedding' | 'rerank' | 'workspace_indexing';
|
||||
byokLeaseId?: string;
|
||||
quotaBackedRoutesAllowed?: boolean;
|
||||
expectedQuotaCalls: number;
|
||||
expectedError?: string;
|
||||
expectedQuotaBackedRoutesAllowed?: boolean;
|
||||
};
|
||||
|
||||
const turnRouteAccessCases: TurnRouteAccessCase[] = [
|
||||
{
|
||||
name: 'checks quota when BYOK does not cover the route',
|
||||
profiles: [],
|
||||
expectedQuotaCalls: 1,
|
||||
expectedError: 'quota exceeded',
|
||||
},
|
||||
{
|
||||
name: 'skips quota when BYOK covers the route',
|
||||
profiles: [{ id: 'profile-1' }],
|
||||
byokLeaseId: 'lease-1',
|
||||
expectedQuotaCalls: 0,
|
||||
expectedQuotaBackedRoutesAllowed: undefined,
|
||||
},
|
||||
{
|
||||
name: 'preserves explicit quota-backed route disable override',
|
||||
profiles: [],
|
||||
quotaBackedRoutesAllowed: false,
|
||||
expectedQuotaCalls: 0,
|
||||
expectedQuotaBackedRoutesAllowed: false,
|
||||
},
|
||||
{
|
||||
name: 'does not check user quota for unmetered service features',
|
||||
profiles: [],
|
||||
featureKind: 'rerank',
|
||||
expectedQuotaCalls: 0,
|
||||
expectedQuotaBackedRoutesAllowed: true,
|
||||
},
|
||||
];
|
||||
|
||||
for (const matrixCase of turnRouteAccessCases) {
|
||||
test(`CopilotAccessPolicy resolve turn route access: ${matrixCase.name}`, async t => {
|
||||
const checkQuota = Sinon.stub().rejects(new Error('quota exceeded'));
|
||||
const getProfiles = Sinon.stub().resolves(matrixCase.profiles);
|
||||
const access = new CopilotAccessPolicy(
|
||||
{ checkQuota } as any,
|
||||
{ getProfiles } as any
|
||||
);
|
||||
|
||||
const promise = access.resolveTurnRouteAccess({
|
||||
userId: 'user-1',
|
||||
workspaceId: 'workspace-1',
|
||||
byokLeaseId: matrixCase.byokLeaseId,
|
||||
featureKind: matrixCase.featureKind,
|
||||
quotaBackedRoutesAllowed: matrixCase.quotaBackedRoutesAllowed,
|
||||
});
|
||||
|
||||
if (matrixCase.expectedError) {
|
||||
await t.throwsAsync(promise, { message: matrixCase.expectedError });
|
||||
} else {
|
||||
const routeAccess = await promise;
|
||||
t.is(
|
||||
routeAccess.quotaBackedRoutesAllowed,
|
||||
matrixCase.expectedQuotaBackedRoutesAllowed
|
||||
);
|
||||
}
|
||||
t.is(checkQuota.callCount, matrixCase.expectedQuotaCalls);
|
||||
if (matrixCase.expectedQuotaCalls) {
|
||||
Sinon.assert.calledWithExactly(checkQuota, 'user-1');
|
||||
}
|
||||
if (matrixCase.byokLeaseId) {
|
||||
Sinon.assert.calledWithMatch(getProfiles, {
|
||||
byokLeaseId: matrixCase.byokLeaseId,
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
type ByokCoverageCase = {
|
||||
featureKind?: ByokFeatureKind;
|
||||
expected: { local: boolean; server: boolean };
|
||||
};
|
||||
|
||||
const byokCoverageCases: ByokCoverageCase[] = [
|
||||
{ featureKind: 'chat', expected: { local: true, server: true } },
|
||||
{ featureKind: 'action', expected: { local: true, server: true } },
|
||||
{ featureKind: 'image', expected: { local: true, server: true } },
|
||||
{ featureKind: 'transcript', expected: { local: false, server: true } },
|
||||
{ featureKind: 'embedding', expected: { local: false, server: true } },
|
||||
{
|
||||
featureKind: 'workspace_indexing',
|
||||
expected: { local: false, server: true },
|
||||
},
|
||||
{ featureKind: 'rerank', expected: { local: false, server: true } },
|
||||
{ expected: { local: true, server: true } },
|
||||
];
|
||||
|
||||
for (const matrixCase of byokCoverageCases) {
|
||||
test(`CopilotAccessPolicy should resolve BYOK coverage for ${matrixCase.featureKind ?? 'default'}`, async t => {
|
||||
const getProfiles = Sinon.stub().resolves([]);
|
||||
const access = new CopilotAccessPolicy(
|
||||
{ hasQuota: Sinon.stub().resolves(true) } as any,
|
||||
{ getProfiles } as any
|
||||
);
|
||||
|
||||
await access.getByokProfiles({
|
||||
userId: 'user-1',
|
||||
workspaceId: 'workspace-1',
|
||||
featureKind: matrixCase.featureKind,
|
||||
});
|
||||
|
||||
t.like(getProfiles.firstCall.args[0], {
|
||||
userId: 'user-1',
|
||||
workspaceId: 'workspace-1',
|
||||
});
|
||||
t.is(getProfiles.firstCall.args[0].featureKind, matrixCase.featureKind);
|
||||
t.deepEqual(getProfiles.firstCall.args[1], matrixCase.expected);
|
||||
});
|
||||
}
|
||||
|
||||
test('CopilotAccessPolicy assertQuotaOrByok should honor quota-backed route disable', async t => {
|
||||
const checkQuota = Sinon.stub().resolves(undefined);
|
||||
const access = new CopilotAccessPolicy(
|
||||
{ checkQuota } as any,
|
||||
{ getProfiles: Sinon.stub().resolves([]) } as any
|
||||
);
|
||||
|
||||
await t.throwsAsync(
|
||||
access.assertQuotaOrByok({
|
||||
userId: 'user-1',
|
||||
workspaceId: 'workspace-1',
|
||||
featureKind: 'transcript',
|
||||
quotaBackedRoutesAllowed: false,
|
||||
})
|
||||
);
|
||||
Sinon.assert.notCalled(checkQuota);
|
||||
});
|
||||
|
||||
test('ConversationHost should delegate empty no-message stream access', async t => {
|
||||
const session = stubConversationSession();
|
||||
const resolveTurnRouteAccess = Sinon.stub().rejects(
|
||||
new Error('quota exceeded')
|
||||
);
|
||||
const host = new ConversationHost(
|
||||
{
|
||||
get: Sinon.stub().resolves(session),
|
||||
revertLatestMessage: Sinon.stub().resolves(undefined),
|
||||
} as any,
|
||||
{} as any,
|
||||
{} as any,
|
||||
{ resolveTurnRouteAccess } as any
|
||||
);
|
||||
|
||||
await t.throwsAsync(host.prepareTurn('user-1', 'session-1', {}), {
|
||||
message: 'quota exceeded',
|
||||
});
|
||||
Sinon.assert.calledOnceWithMatch(resolveTurnRouteAccess, {
|
||||
userId: 'user-1',
|
||||
workspaceId: 'workspace-1',
|
||||
});
|
||||
});
|
||||
|
||||
test('ConversationHost should return access decision for empty no-message stream', async t => {
|
||||
const session = stubConversationSession();
|
||||
const resolveTurnRouteAccess = Sinon.stub().resolves({
|
||||
byokProfiles: [{ id: 'profile-1' }],
|
||||
quotaBackedRoutesAllowed: undefined,
|
||||
});
|
||||
const host = new ConversationHost(
|
||||
{
|
||||
get: Sinon.stub().resolves(session),
|
||||
revertLatestMessage: Sinon.stub().resolves(undefined),
|
||||
} as any,
|
||||
{} as any,
|
||||
{} as any,
|
||||
{ resolveTurnRouteAccess } as any
|
||||
);
|
||||
|
||||
const prepared = await host.prepareTurn('user-1', 'session-1', {});
|
||||
|
||||
t.is(prepared.latestTurn, undefined);
|
||||
t.is(prepared.quotaBackedRoutesAllowed, undefined);
|
||||
Sinon.assert.calledOnce(resolveTurnRouteAccess);
|
||||
});
|
||||
|
||||
test('ConversationHost should replay accepted tokens without rechecking quota', async t => {
|
||||
const acceptedTurn: Turn = {
|
||||
id: 'turn-1',
|
||||
conversationId: 'session-1',
|
||||
role: 'user',
|
||||
content: 'hello',
|
||||
attachments: [],
|
||||
metadata: {},
|
||||
renderTrace: [],
|
||||
toolEvents: [],
|
||||
createdAt: new Date(),
|
||||
};
|
||||
const session = {
|
||||
...stubConversationSession(acceptedTurn),
|
||||
findTurn: Sinon.stub().withArgs('turn-1').returns(acceptedTurn),
|
||||
};
|
||||
const resolveTurnRouteAccess = Sinon.stub().rejects(
|
||||
new Error('quota exceeded')
|
||||
);
|
||||
const host = new ConversationHost(
|
||||
{
|
||||
get: Sinon.stub().resolves(session),
|
||||
revertLatestMessage: Sinon.stub().resolves(undefined),
|
||||
} as any,
|
||||
{
|
||||
getAccepted: Sinon.stub().resolves({
|
||||
sessionId: 'session-1',
|
||||
turnId: 'turn-1',
|
||||
}),
|
||||
} as any,
|
||||
{} as any,
|
||||
{ resolveTurnRouteAccess } as any
|
||||
);
|
||||
|
||||
const prepared = await host.prepareTurn('user-1', 'session-1', {
|
||||
messageId: 'message-1',
|
||||
});
|
||||
|
||||
t.is(prepared.latestTurn, acceptedTurn);
|
||||
t.true(prepared.quotaBackedRoutesAllowed);
|
||||
Sinon.assert.notCalled(resolveTurnRouteAccess);
|
||||
});
|
||||
|
||||
test('ConversationHost should replay durable tokens without rechecking quota', async t => {
|
||||
const durableTurn: Turn = {
|
||||
id: 'turn-1',
|
||||
conversationId: 'session-1',
|
||||
role: 'user',
|
||||
content: 'hello',
|
||||
attachments: [],
|
||||
metadata: {},
|
||||
renderTrace: [],
|
||||
toolEvents: [],
|
||||
createdAt: new Date(),
|
||||
};
|
||||
const session = {
|
||||
...stubConversationSession(durableTurn),
|
||||
findTurn: Sinon.stub().withArgs('turn-1').returns(durableTurn),
|
||||
pushPersistedTurn: Sinon.stub(),
|
||||
};
|
||||
const resolveTurnRouteAccess = Sinon.stub().rejects(
|
||||
new Error('quota exceeded')
|
||||
);
|
||||
const markAccepted = Sinon.stub().resolves(undefined);
|
||||
const host = new ConversationHost(
|
||||
{
|
||||
get: Sinon.stub().resolves(session),
|
||||
findTurnByCompatSubmissionId: Sinon.stub().resolves(durableTurn),
|
||||
revertLatestMessage: Sinon.stub().resolves(undefined),
|
||||
} as any,
|
||||
{
|
||||
getAccepted: Sinon.stub().resolves(undefined),
|
||||
markAccepted,
|
||||
} as any,
|
||||
{
|
||||
acquire: Sinon.stub().resolves({
|
||||
[Symbol.asyncDispose]: Sinon.stub().resolves(undefined),
|
||||
}),
|
||||
} as any,
|
||||
{ resolveTurnRouteAccess } as any
|
||||
);
|
||||
|
||||
const prepared = await host.prepareTurn('user-1', 'session-1', {
|
||||
messageId: 'message-1',
|
||||
});
|
||||
|
||||
t.is(prepared.latestTurn, durableTurn);
|
||||
t.true(prepared.quotaBackedRoutesAllowed);
|
||||
Sinon.assert.calledOnceWithMatch(markAccepted, 'message-1', {
|
||||
sessionId: 'session-1',
|
||||
turnId: 'turn-1',
|
||||
});
|
||||
Sinon.assert.notCalled(resolveTurnRouteAccess);
|
||||
});
|
||||
|
||||
test('ToolRuntime should pass route context into prompt-backed tools', async t => {
|
||||
const promptRuntime = {
|
||||
runText: Sinon.stub().resolves('<html><body>done</body></html>'),
|
||||
};
|
||||
const runtime = new ToolRuntime(
|
||||
{} as any,
|
||||
{} as any,
|
||||
{} as any,
|
||||
{} as any,
|
||||
{} as any,
|
||||
{} as any,
|
||||
promptRuntime as any,
|
||||
{} as any
|
||||
);
|
||||
|
||||
const tools = await runtime.getTools(
|
||||
{
|
||||
tools: ['codeArtifact'],
|
||||
user: 'user-1',
|
||||
session: 'session-1',
|
||||
workspace: 'workspace-1',
|
||||
byokLeaseId: 'lease-1',
|
||||
featureKind: 'chat',
|
||||
quotaBackedRoutesAllowed: false,
|
||||
},
|
||||
'gpt-4o-mini'
|
||||
);
|
||||
|
||||
const result = await tools.code_artifact.execute?.(
|
||||
{ title: 'Demo', userPrompt: 'build a page' },
|
||||
{}
|
||||
);
|
||||
|
||||
t.like(result as object, { title: 'Demo' });
|
||||
Sinon.assert.calledOnceWithMatch(
|
||||
promptRuntime.runText,
|
||||
'Code Artifact',
|
||||
{ content: 'build a page' },
|
||||
{
|
||||
providerOptions: {
|
||||
user: 'user-1',
|
||||
session: 'session-1',
|
||||
workspace: 'workspace-1',
|
||||
byokLeaseId: 'lease-1',
|
||||
featureKind: 'chat',
|
||||
quotaBackedRoutesAllowed: false,
|
||||
},
|
||||
}
|
||||
);
|
||||
});
|
||||
|
||||
test('ResponsePostprocessor should build text, object and image assistant turns', t => {
|
||||
const postprocessor = new ResponsePostprocessor();
|
||||
|
||||
@@ -267,7 +634,7 @@ test('action result projection should map image result url to assistant attachme
|
||||
t.deepEqual(turn?.attachments, ['https://example.com/final.png']);
|
||||
});
|
||||
|
||||
test('CopilotEmbeddingClientService should refresh configured client and clear unavailable client', async t => {
|
||||
test('CopilotEmbeddingClientService should keep dispatch client across global config refreshes', async t => {
|
||||
const taskPolicy = {
|
||||
resolveEmbeddingModelId: () => 'text-embedding-3-large',
|
||||
};
|
||||
@@ -288,8 +655,8 @@ test('CopilotEmbeddingClientService should refresh configured client and clear u
|
||||
t.truthy(service.getClient());
|
||||
|
||||
const second = await service.refresh();
|
||||
t.is(second, undefined);
|
||||
t.is(service.getClient(), undefined);
|
||||
t.truthy(second);
|
||||
t.is(service.getClient(), second);
|
||||
Sinon.assert.calledTwice(runtime.embeddingConfigured);
|
||||
Sinon.assert.alwaysCalledWithExactly(
|
||||
runtime.embeddingConfigured,
|
||||
@@ -297,6 +664,95 @@ test('CopilotEmbeddingClientService should refresh configured client and clear u
|
||||
);
|
||||
});
|
||||
|
||||
test('CopilotEmbeddingClientService should keep workspace-routed embedding client without global provider', async t => {
|
||||
const taskPolicy = {
|
||||
resolveEmbeddingModelId: () => 'gemini-embedding-001',
|
||||
resolveRerankModelId: () => 'gpt-4o-mini',
|
||||
};
|
||||
const runtime = {
|
||||
embeddingConfigured: Sinon.stub().resolves(false),
|
||||
};
|
||||
const service = new CopilotEmbeddingClientService(
|
||||
taskPolicy as any,
|
||||
runtime as any
|
||||
);
|
||||
|
||||
const client = await service.refresh();
|
||||
|
||||
t.truthy(client);
|
||||
t.is(service.getClient(), client);
|
||||
Sinon.assert.calledOnceWithExactly(
|
||||
runtime.embeddingConfigured,
|
||||
'gemini-embedding-001'
|
||||
);
|
||||
});
|
||||
|
||||
test('CopilotEmbeddingClientService should pass workspace context into embedding routes', async t => {
|
||||
const signal = new AbortController().signal;
|
||||
const taskPolicy = {
|
||||
resolveEmbeddingModelId: () => 'gemini-embedding-001',
|
||||
resolveRerankModelId: () => 'gpt-4o-mini',
|
||||
};
|
||||
const runtime = {
|
||||
embeddingConfigured: Sinon.stub().resolves(true),
|
||||
embed: Sinon.stub().resolves([[0.1]]),
|
||||
rerank: Sinon.stub().resolves([0.8]),
|
||||
};
|
||||
const service = new CopilotEmbeddingClientService(
|
||||
taskPolicy as any,
|
||||
runtime as any
|
||||
);
|
||||
const client = await service.refresh();
|
||||
|
||||
t.truthy(client);
|
||||
await client?.getEmbeddings(['hello'], {
|
||||
workspaceId: 'workspace-1',
|
||||
userId: 'user-1',
|
||||
featureKind: 'workspace_indexing',
|
||||
signal,
|
||||
});
|
||||
|
||||
Sinon.assert.calledOnceWithMatch(
|
||||
runtime.embed,
|
||||
'gemini-embedding-001',
|
||||
['hello'],
|
||||
{
|
||||
dimensions: Sinon.match.number,
|
||||
workspace: 'workspace-1',
|
||||
user: 'user-1',
|
||||
featureKind: 'workspace_indexing',
|
||||
signal,
|
||||
}
|
||||
);
|
||||
|
||||
await client?.reRank(
|
||||
'hello',
|
||||
[{ chunk: 0, content: 'hello', distance: 0.2 }],
|
||||
1,
|
||||
{
|
||||
workspaceId: 'workspace-1',
|
||||
userId: 'user-1',
|
||||
featureKind: 'workspace_indexing',
|
||||
signal,
|
||||
}
|
||||
);
|
||||
|
||||
Sinon.assert.calledOnceWithMatch(
|
||||
runtime.rerank,
|
||||
'gpt-4o-mini',
|
||||
{
|
||||
query: 'hello',
|
||||
candidates: [{ id: '0', text: 'hello' }],
|
||||
},
|
||||
{
|
||||
workspace: 'workspace-1',
|
||||
user: 'user-1',
|
||||
featureKind: 'rerank',
|
||||
signal,
|
||||
}
|
||||
);
|
||||
});
|
||||
|
||||
test('CompatHistoryProjector should compose visibility, prompt preload and attachment url projection', t => {
|
||||
const projector = new CompatHistoryProjector(
|
||||
new HistoryVisibilityPolicy(),
|
||||
|
||||
@@ -20,7 +20,6 @@ import {
|
||||
import { GeminiProvider } from '../../plugins/copilot/providers/gemini/gemini';
|
||||
import { GeminiVertexProvider } from '../../plugins/copilot/providers/gemini/vertex';
|
||||
import { OpenAIProvider } from '../../plugins/copilot/providers/openai';
|
||||
import { PerplexityProvider } from '../../plugins/copilot/providers/perplexity';
|
||||
import {
|
||||
CopilotProviderType,
|
||||
type PromptMessage,
|
||||
@@ -589,16 +588,6 @@ class TestOpenAIProvider extends OpenAIProvider {
|
||||
}
|
||||
}
|
||||
|
||||
class TestPerplexityProvider extends PerplexityProvider {
|
||||
override get config() {
|
||||
return { apiKey: 'perplexity-key' };
|
||||
}
|
||||
|
||||
override configured() {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
test('NativeProviderAdapter should append citation and attachment footnotes', async t => {
|
||||
const dispatch = () =>
|
||||
(async function* (): AsyncIterableIterator<LlmToolLoopStreamEvent> {
|
||||
@@ -818,6 +807,91 @@ test('NativeProviderAdapter streamObject should map tool and text events', async
|
||||
t.snapshot(events);
|
||||
});
|
||||
|
||||
test('NativeProviderAdapter streamObject should finalize usage with selected provider', async t => {
|
||||
const usageEvents: Array<{
|
||||
providerId: string;
|
||||
model?: string;
|
||||
usage?: {
|
||||
prompt_tokens: number;
|
||||
completion_tokens: number;
|
||||
total_tokens: number;
|
||||
cached_tokens?: number;
|
||||
};
|
||||
}> = [];
|
||||
const adapter = new NativeProviderAdapter(
|
||||
() =>
|
||||
stream(() => [
|
||||
{ type: 'message_start', model: 'gpt-5-mini' },
|
||||
{ type: 'text_delta', text: 'ok' },
|
||||
{
|
||||
type: 'done',
|
||||
finish_reason: 'stop',
|
||||
usage: { prompt_tokens: 2, completion_tokens: 3, total_tokens: 5 },
|
||||
},
|
||||
{
|
||||
type: 'provider_selected',
|
||||
provider_id: 'byok-aaaaaaaaaaaa-openai-server-key1',
|
||||
},
|
||||
]),
|
||||
{
|
||||
onUsage: input => {
|
||||
usageEvents.push(input);
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
const events = await collectChunks(
|
||||
adapter.streamObject({
|
||||
model: 'gpt-5-mini',
|
||||
stream: true,
|
||||
messages: nativeMessages(nativeUserText('hi')),
|
||||
})
|
||||
);
|
||||
|
||||
t.deepEqual(events, [{ type: 'text-delta', textDelta: 'ok' }]);
|
||||
t.deepEqual(usageEvents, [
|
||||
{
|
||||
providerId: 'byok-aaaaaaaaaaaa-openai-server-key1',
|
||||
model: 'gpt-5-mini',
|
||||
usage: { prompt_tokens: 2, completion_tokens: 3, total_tokens: 5 },
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
test('NativeProviderAdapter streamObject should keep streaming when usage callback fails', async t => {
|
||||
const adapter = new NativeProviderAdapter(
|
||||
() =>
|
||||
stream(() => [
|
||||
{ type: 'message_start', model: 'gpt-5-mini' },
|
||||
{ type: 'text_delta', text: 'ok' },
|
||||
{
|
||||
type: 'done',
|
||||
finish_reason: 'stop',
|
||||
usage: { prompt_tokens: 2, completion_tokens: 3, total_tokens: 5 },
|
||||
},
|
||||
{
|
||||
type: 'provider_selected',
|
||||
provider_id: 'byok-aaaaaaaaaaaa-openai-server-key1',
|
||||
},
|
||||
]),
|
||||
{
|
||||
onUsage: () => {
|
||||
throw new Error('usage callback failed');
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
const events = await collectChunks(
|
||||
adapter.streamObject({
|
||||
model: 'gpt-5-mini',
|
||||
stream: true,
|
||||
messages: nativeMessages(nativeUserText('hi')),
|
||||
})
|
||||
);
|
||||
|
||||
t.deepEqual(events, [{ type: 'text-delta', textDelta: 'ok' }]);
|
||||
});
|
||||
|
||||
test('NativeRuntimeAdapter streamObject should keep raw runtime stream objects only', async t => {
|
||||
const adapter = new NativeRuntimeAdapter(
|
||||
createTestToolLoopBridge(mockDispatch, {}, 3)
|
||||
@@ -1653,36 +1727,6 @@ test('GeminiProvider should not pass materialized inline attachment URL to nativ
|
||||
t.false('url' in (attachmentPart?.source ?? {}));
|
||||
});
|
||||
|
||||
test('PerplexityProvider should ignore attachments during text model matching', async t => {
|
||||
const provider = new TestPerplexityProvider();
|
||||
let capturedRequest: LlmRequest | undefined;
|
||||
|
||||
(provider as any).getActiveProviderMiddleware = () => ({});
|
||||
(provider as any).getTools = async () => ({});
|
||||
(provider as any).createNativeAdapter = () => ({
|
||||
text: async (request: LlmRequest) => {
|
||||
capturedRequest = request;
|
||||
return 'ok';
|
||||
},
|
||||
});
|
||||
|
||||
const result = await getProviderRuntimeHost(provider).run.text(
|
||||
{ modelId: 'sonar' },
|
||||
[
|
||||
{
|
||||
role: 'user',
|
||||
content: 'summarize this',
|
||||
attachments: ['https://example.com/a.pdf'],
|
||||
params: { mimetype: 'application/pdf' },
|
||||
},
|
||||
],
|
||||
{}
|
||||
);
|
||||
|
||||
t.is(result, 'ok');
|
||||
t.snapshot(capturedRequest?.messages[0]?.content);
|
||||
});
|
||||
|
||||
test('GeminiProvider should reject unsupported attachment schemes at input validation', async t => {
|
||||
const provider = new TestGeminiProvider();
|
||||
|
||||
|
||||
@@ -3,7 +3,11 @@ import test from 'ava';
|
||||
import Sinon from 'sinon';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { CopilotPromptInvalid, NoCopilotProviderAvailable } from '../../base';
|
||||
import {
|
||||
CopilotPromptInvalid,
|
||||
CopilotQuotaExceeded,
|
||||
NoCopilotProviderAvailable,
|
||||
} from '../../base';
|
||||
import {
|
||||
type LlmBackendConfig,
|
||||
type LlmEmbeddingRequest,
|
||||
@@ -20,9 +24,11 @@ import {
|
||||
llmResolveRequestedModelMatch,
|
||||
type LlmStructuredRequest,
|
||||
} from '../../native';
|
||||
import type { ProviderMiddlewareConfig } from '../../plugins/copilot/config';
|
||||
import type {
|
||||
CopilotProviderProfile,
|
||||
ProviderMiddlewareConfig,
|
||||
} from '../../plugins/copilot/config';
|
||||
import { CopilotProviderFactory } from '../../plugins/copilot/providers/factory';
|
||||
import { MorphProvider } from '../../plugins/copilot/providers/morph';
|
||||
import { OpenAIProvider } from '../../plugins/copilot/providers/openai';
|
||||
import { CopilotProvider } from '../../plugins/copilot/providers/provider';
|
||||
import { buildProviderRegistry } from '../../plugins/copilot/providers/provider-registry';
|
||||
@@ -62,6 +68,13 @@ import {
|
||||
userPrompt,
|
||||
} from './prompt-test-helper';
|
||||
|
||||
function createNativeExecutionEngine() {
|
||||
return new NativeExecutionEngine({
|
||||
recordUsage: Sinon.stub().resolves(),
|
||||
recordProviderFailure: Sinon.stub().resolves(),
|
||||
} as never);
|
||||
}
|
||||
|
||||
function structuredOptions(
|
||||
schema: z.ZodTypeAny,
|
||||
extra?: Record<string, unknown>
|
||||
@@ -541,7 +554,7 @@ test('CapabilityRuntime should defer no-route embedding plans to native engine',
|
||||
});
|
||||
|
||||
test('NativeExecutionEngine should expose execute/executeStream as the single plan entrypoints', async t => {
|
||||
const engine = new NativeExecutionEngine();
|
||||
const engine = createNativeExecutionEngine();
|
||||
let dispatchCalls = 0;
|
||||
let streamCalls = 0;
|
||||
|
||||
@@ -660,6 +673,279 @@ test('NativeExecutionEngine should expose execute/executeStream as the single pl
|
||||
t.is(streamCalls, 1);
|
||||
});
|
||||
|
||||
test('NativeExecutionEngine should record BYOK usage when stream finalizes with selected provider', async t => {
|
||||
const byok = {
|
||||
recordUsage: Sinon.stub().resolves(),
|
||||
};
|
||||
const engine = new NativeExecutionEngine(byok as never);
|
||||
const providerId = 'byok-aaaaaaaaaaaa-openai-server-key1';
|
||||
|
||||
const originalStream = (serverNativeModule as any).llmDispatchPreparedStream;
|
||||
(serverNativeModule as any).llmDispatchPreparedStream = (
|
||||
_routesJson: string,
|
||||
callback: (error: Error | null, arg: string) => void
|
||||
) => {
|
||||
callback(
|
||||
null,
|
||||
JSON.stringify({
|
||||
type: 'message_start',
|
||||
model: 'gpt-5-mini',
|
||||
})
|
||||
);
|
||||
callback(null, JSON.stringify({ type: 'text_delta', text: 'ok' }));
|
||||
callback(
|
||||
null,
|
||||
JSON.stringify({
|
||||
type: 'done',
|
||||
finish_reason: 'stop',
|
||||
usage: {
|
||||
prompt_tokens: 2,
|
||||
completion_tokens: 3,
|
||||
total_tokens: 5,
|
||||
},
|
||||
})
|
||||
);
|
||||
callback(
|
||||
null,
|
||||
JSON.stringify({
|
||||
type: 'provider_selected',
|
||||
provider_id: providerId,
|
||||
})
|
||||
);
|
||||
callback(null, '__AFFINE_LLM_STREAM_END__');
|
||||
return { abort() {} };
|
||||
};
|
||||
t.teardown(() => {
|
||||
(serverNativeModule as any).llmDispatchPreparedStream = originalStream;
|
||||
});
|
||||
|
||||
const chunks = await collectAsync(
|
||||
engine.executeStream({
|
||||
nativeDispatch: {
|
||||
chat: {
|
||||
routes: [
|
||||
nativeRoute({
|
||||
providerId,
|
||||
authToken: 'byok-key',
|
||||
request: nativeTextRequest('hello'),
|
||||
}),
|
||||
],
|
||||
prepared: {
|
||||
route: preparedRoute({
|
||||
providerId,
|
||||
authToken: 'byok-key',
|
||||
}),
|
||||
request: nativeTextRequest('hello'),
|
||||
tools: {},
|
||||
postprocess: { nodeTextMiddleware: [] },
|
||||
},
|
||||
hasTools: false,
|
||||
},
|
||||
},
|
||||
request: {
|
||||
kind: 'streamText',
|
||||
cond: { modelId: 'gpt-5-mini' },
|
||||
messages: singleUserPromptMessages('hello'),
|
||||
options: {
|
||||
workspace: 'workspace-1',
|
||||
user: 'user-1',
|
||||
session: 'session-1',
|
||||
featureKind: 'chat',
|
||||
},
|
||||
},
|
||||
routePolicy: { fallbackOrder: [providerId] },
|
||||
runtimePolicy: {},
|
||||
attachmentPolicy: { materializeRemoteAttachments: true },
|
||||
responsePostprocess: { mode: 'streamText' },
|
||||
hostPersistence: { persistAssistantTurn: true, outputKind: 'streamText' },
|
||||
hostContext: {},
|
||||
})
|
||||
);
|
||||
|
||||
t.deepEqual(chunks, ['ok']);
|
||||
Sinon.assert.calledOnceWithMatch(byok.recordUsage, {
|
||||
workspaceId: 'workspace-1',
|
||||
userId: 'user-1',
|
||||
sessionId: 'session-1',
|
||||
featureKind: 'chat',
|
||||
providerId,
|
||||
model: 'gpt-5-mini',
|
||||
usage: {
|
||||
prompt_tokens: 2,
|
||||
completion_tokens: 3,
|
||||
total_tokens: 5,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
test('NativeExecutionEngine should record plain text BYOK usage as chat by default', async t => {
|
||||
const byok = {
|
||||
recordUsage: Sinon.stub().resolves(),
|
||||
recordProviderFailure: Sinon.stub().resolves(),
|
||||
};
|
||||
const engine = new NativeExecutionEngine(byok as never);
|
||||
const providerId = 'byok-aaaaaaaaaaaa-openai-server-key1';
|
||||
|
||||
const originalDispatch = (serverNativeModule as any).llmDispatchPrepared;
|
||||
(serverNativeModule as any).llmDispatchPrepared = () => {
|
||||
return JSON.stringify({
|
||||
provider_id: providerId,
|
||||
response: {
|
||||
id: 'chat_execute',
|
||||
model: 'gpt-5-mini',
|
||||
message: {
|
||||
role: 'assistant',
|
||||
content: [{ type: 'text', text: 'execute-ok' }],
|
||||
},
|
||||
usage: {
|
||||
prompt_tokens: 1,
|
||||
completion_tokens: 2,
|
||||
total_tokens: 3,
|
||||
},
|
||||
finish_reason: 'stop',
|
||||
},
|
||||
});
|
||||
};
|
||||
t.teardown(() => {
|
||||
(serverNativeModule as any).llmDispatchPrepared = originalDispatch;
|
||||
});
|
||||
|
||||
const text = await engine.execute({
|
||||
nativeDispatch: {
|
||||
chat: {
|
||||
routes: [
|
||||
nativeRoute({
|
||||
providerId,
|
||||
authToken: 'byok-key',
|
||||
request: nativeTextRequest('hello'),
|
||||
}),
|
||||
],
|
||||
prepared: {
|
||||
route: preparedRoute({
|
||||
providerId,
|
||||
authToken: 'byok-key',
|
||||
}),
|
||||
request: nativeTextRequest('hello'),
|
||||
tools: {},
|
||||
postprocess: { nodeTextMiddleware: [] },
|
||||
},
|
||||
hasTools: false,
|
||||
},
|
||||
},
|
||||
request: {
|
||||
kind: 'text',
|
||||
cond: { modelId: 'gpt-5-mini' },
|
||||
messages: singleUserPromptMessages('hello'),
|
||||
options: {
|
||||
workspace: 'workspace-1',
|
||||
user: 'user-1',
|
||||
session: 'session-1',
|
||||
},
|
||||
},
|
||||
routePolicy: { fallbackOrder: [providerId] },
|
||||
runtimePolicy: {},
|
||||
attachmentPolicy: { materializeRemoteAttachments: true },
|
||||
responsePostprocess: { mode: 'text' },
|
||||
hostPersistence: { persistAssistantTurn: true, outputKind: 'text' },
|
||||
hostContext: {},
|
||||
});
|
||||
|
||||
t.is(text, 'execute-ok');
|
||||
Sinon.assert.calledOnceWithMatch(byok.recordUsage, {
|
||||
workspaceId: 'workspace-1',
|
||||
userId: 'user-1',
|
||||
sessionId: 'session-1',
|
||||
featureKind: 'chat',
|
||||
providerId,
|
||||
model: 'gpt-5-mini',
|
||||
usage: {
|
||||
prompt_tokens: 1,
|
||||
completion_tokens: 2,
|
||||
total_tokens: 3,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
test('NativeExecutionEngine should not fail stream when BYOK usage recording fails', async t => {
|
||||
const byok = {
|
||||
recordUsage: Sinon.stub().rejects(new Error('usage db down')),
|
||||
};
|
||||
const engine = new NativeExecutionEngine(byok as never);
|
||||
const providerId = 'byok-aaaaaaaaaaaa-openai-server-key1';
|
||||
|
||||
const originalStream = (serverNativeModule as any).llmDispatchPreparedStream;
|
||||
(serverNativeModule as any).llmDispatchPreparedStream = (
|
||||
_routesJson: string,
|
||||
callback: (error: Error | null, arg: string) => void
|
||||
) => {
|
||||
callback(
|
||||
null,
|
||||
JSON.stringify({ type: 'message_start', model: 'gpt-5-mini' })
|
||||
);
|
||||
callback(null, JSON.stringify({ type: 'text_delta', text: 'ok' }));
|
||||
callback(
|
||||
null,
|
||||
JSON.stringify({
|
||||
type: 'done',
|
||||
finish_reason: 'stop',
|
||||
usage: { prompt_tokens: 2, completion_tokens: 3, total_tokens: 5 },
|
||||
})
|
||||
);
|
||||
callback(
|
||||
null,
|
||||
JSON.stringify({ type: 'provider_selected', provider_id: providerId })
|
||||
);
|
||||
callback(null, '__AFFINE_LLM_STREAM_END__');
|
||||
return { abort() {} };
|
||||
};
|
||||
t.teardown(() => {
|
||||
(serverNativeModule as any).llmDispatchPreparedStream = originalStream;
|
||||
});
|
||||
|
||||
const chunks = await collectAsync(
|
||||
engine.executeStream({
|
||||
nativeDispatch: {
|
||||
chat: {
|
||||
routes: [
|
||||
nativeRoute({
|
||||
providerId,
|
||||
authToken: 'byok-key',
|
||||
request: nativeTextRequest('hello'),
|
||||
}),
|
||||
],
|
||||
prepared: {
|
||||
route: preparedRoute({ providerId, authToken: 'byok-key' }),
|
||||
request: nativeTextRequest('hello'),
|
||||
tools: {},
|
||||
postprocess: { nodeTextMiddleware: [] },
|
||||
},
|
||||
hasTools: false,
|
||||
},
|
||||
},
|
||||
request: {
|
||||
kind: 'streamText',
|
||||
cond: { modelId: 'gpt-5-mini' },
|
||||
messages: singleUserPromptMessages('hello'),
|
||||
options: {
|
||||
workspace: 'workspace-1',
|
||||
user: 'user-1',
|
||||
session: 'session-1',
|
||||
featureKind: 'chat',
|
||||
},
|
||||
},
|
||||
routePolicy: { fallbackOrder: [providerId] },
|
||||
runtimePolicy: {},
|
||||
attachmentPolicy: { materializeRemoteAttachments: true },
|
||||
responsePostprocess: { mode: 'streamText' },
|
||||
hostPersistence: { persistAssistantTurn: true, outputKind: 'streamText' },
|
||||
hostContext: {},
|
||||
})
|
||||
);
|
||||
|
||||
t.deepEqual(chunks, ['ok']);
|
||||
Sinon.assert.calledOnce(byok.recordUsage);
|
||||
});
|
||||
|
||||
test('CopilotProviderFactory should return no prepared routes when native prepare returns null', async t => {
|
||||
const provider = new DriverOnlyProvider();
|
||||
(provider as any).AFFiNEConfig = { copilot: { providers: { openai: {} } } };
|
||||
@@ -693,9 +979,16 @@ test('CopilotProviderFactory should return no prepared routes when native prepar
|
||||
enableFeature: Sinon.stub(),
|
||||
disableFeature: Sinon.stub(),
|
||||
};
|
||||
const access = {
|
||||
resolveRouteAccess: Sinon.stub().resolves({
|
||||
byokProfiles: [],
|
||||
quotaBackedRoutesAvailable: true,
|
||||
}),
|
||||
};
|
||||
const factory = new CopilotProviderFactory(
|
||||
server as never,
|
||||
registryService as never
|
||||
registryService as never,
|
||||
access as never
|
||||
);
|
||||
factory.register('openai-main', provider);
|
||||
|
||||
@@ -923,52 +1216,6 @@ test('driver-only provider should require explicit structured response contracts
|
||||
t.is(capturedRequest, undefined);
|
||||
});
|
||||
|
||||
test('MorphProvider should reuse the base native chat driver template', async t => {
|
||||
const provider = new MorphProvider();
|
||||
(provider as any).AFFiNEConfig = {
|
||||
copilot: { providers: { morph: { apiKey: 'test-key' } } },
|
||||
};
|
||||
(provider as any).toolExecutorHost = {
|
||||
createNativeAdapter: () => ({
|
||||
text: async () => 'morph text',
|
||||
streamText: async function* () {
|
||||
yield 'morph stream';
|
||||
},
|
||||
streamObject: async function* () {
|
||||
yield { type: 'text-delta', textDelta: 'unused' };
|
||||
},
|
||||
}),
|
||||
getTools: async () => ({}),
|
||||
};
|
||||
|
||||
t.is(
|
||||
await getProviderRuntimeHost(provider).run.text(
|
||||
{ modelId: 'morph-v3-fast' },
|
||||
promptMessages(userPrompt('hello'))
|
||||
),
|
||||
'morph text'
|
||||
);
|
||||
t.deepEqual(
|
||||
await collectAsync(
|
||||
getProviderRuntimeHost(provider).run.streamText(
|
||||
{ modelId: 'morph-v3-fast' },
|
||||
promptMessages(userPrompt('hello'))
|
||||
)
|
||||
),
|
||||
['morph stream']
|
||||
);
|
||||
t.is(
|
||||
await getProviderRuntimeHost(provider).prepare.chat(
|
||||
'streamObject',
|
||||
{
|
||||
modelId: 'morph-v3-fast',
|
||||
},
|
||||
promptMessages(userPrompt('hello'))
|
||||
),
|
||||
null
|
||||
);
|
||||
});
|
||||
|
||||
test('getActiveProviderMiddleware should merge defaults with profile override', t => {
|
||||
const provider = createProvider({
|
||||
rust: { request: ['clamp_max_tokens'] },
|
||||
@@ -1231,9 +1478,16 @@ test('CopilotProviderFactory should resolve legacy model ids through native regi
|
||||
enableFeature: Sinon.stub(),
|
||||
disableFeature: Sinon.stub(),
|
||||
};
|
||||
const access = {
|
||||
resolveRouteAccess: Sinon.stub().resolves({
|
||||
byokProfiles: [],
|
||||
quotaBackedRoutesAvailable: true,
|
||||
}),
|
||||
};
|
||||
const factory = new CopilotProviderFactory(
|
||||
server as never,
|
||||
registryService as never
|
||||
registryService as never,
|
||||
access as never
|
||||
);
|
||||
factory.register('openai-main', provider);
|
||||
|
||||
@@ -1242,6 +1496,230 @@ test('CopilotProviderFactory should resolve legacy model ids through native regi
|
||||
t.is(provider.resolveModel('gpt-5-2025-08-07')?.id, 'gpt-5');
|
||||
});
|
||||
|
||||
const BYOK_OPENAI_PROFILE: CopilotProviderProfile = {
|
||||
id: 'byok-aaaaaaaaaaaa-openai-server-key1',
|
||||
type: CopilotProviderType.OpenAI,
|
||||
priority: 10_000,
|
||||
config: { apiKey: 'byok-key' },
|
||||
};
|
||||
|
||||
const BYOK_FAL_PROFILE: CopilotProviderProfile = {
|
||||
id: 'byok-aaaaaaaaaaaa-fal-server-key1',
|
||||
type: CopilotProviderType.FAL,
|
||||
priority: 10_000,
|
||||
config: { apiKey: 'byok-key' },
|
||||
};
|
||||
|
||||
function createProviderFactoryWithByokRoutes({
|
||||
byokProfiles = [BYOK_OPENAI_PROFILE],
|
||||
hasQuota = true,
|
||||
}: {
|
||||
byokProfiles?: CopilotProviderProfile[];
|
||||
hasQuota?: boolean;
|
||||
} = {}) {
|
||||
const provider = createProvider();
|
||||
const registryService = {
|
||||
getRegistry: () =>
|
||||
buildProviderRegistry({
|
||||
profiles: [
|
||||
{
|
||||
id: 'openai-main',
|
||||
type: CopilotProviderType.OpenAI,
|
||||
priority: 1,
|
||||
config: { apiKey: 'test-key' },
|
||||
},
|
||||
],
|
||||
defaults: {},
|
||||
}),
|
||||
};
|
||||
const server = {
|
||||
enableFeature: Sinon.stub(),
|
||||
disableFeature: Sinon.stub(),
|
||||
};
|
||||
const byok = {
|
||||
getProfiles: Sinon.stub().resolves(byokProfiles),
|
||||
};
|
||||
const access = {
|
||||
resolveRouteAccess: Sinon.stub().callsFake(async context => ({
|
||||
byokProfiles: await byok.getProfiles(context),
|
||||
quotaBackedRoutesAvailable: context.quotaBackedRoutesAllowed ?? hasQuota,
|
||||
})),
|
||||
};
|
||||
const factory = new CopilotProviderFactory(
|
||||
server as never,
|
||||
registryService as never,
|
||||
access as never
|
||||
);
|
||||
factory.register('openai-main', provider);
|
||||
|
||||
return { factory, byok };
|
||||
}
|
||||
|
||||
test('CopilotProviderFactory should use matching BYOK routes before quota-backed routes', async t => {
|
||||
const { factory } = createProviderFactoryWithByokRoutes();
|
||||
|
||||
const routes = await factory.resolveRoutes(
|
||||
{ modelId: 'gpt-5-mini', outputType: ModelOutputType.Text },
|
||||
{},
|
||||
{ userId: 'user-1', workspaceId: 'workspace-1' }
|
||||
);
|
||||
|
||||
t.deepEqual(
|
||||
routes.map(route => route.providerId),
|
||||
['byok-aaaaaaaaaaaa-openai-server-key1']
|
||||
);
|
||||
});
|
||||
|
||||
test('CopilotProviderFactory should skip unsupported BYOK profiles and use quota-backed fallback', async t => {
|
||||
const { factory } = createProviderFactoryWithByokRoutes({
|
||||
byokProfiles: [BYOK_FAL_PROFILE],
|
||||
});
|
||||
|
||||
const routes = await factory.resolveRoutes(
|
||||
{ modelId: 'gpt-5-mini', outputType: ModelOutputType.Text },
|
||||
{},
|
||||
{ userId: 'user-1', workspaceId: 'workspace-1' }
|
||||
);
|
||||
|
||||
t.deepEqual(
|
||||
routes.map(route => route.providerId),
|
||||
['openai-main']
|
||||
);
|
||||
});
|
||||
|
||||
test('CopilotProviderFactory should resolve BYOK embedding routes with workspace context', async t => {
|
||||
const { factory, byok } = createProviderFactoryWithByokRoutes();
|
||||
|
||||
const routes = await factory.resolveRoutes(
|
||||
{
|
||||
modelId: 'text-embedding-3-small',
|
||||
outputType: ModelOutputType.Embedding,
|
||||
},
|
||||
{},
|
||||
{ workspaceId: 'workspace-1', featureKind: 'workspace_indexing' }
|
||||
);
|
||||
|
||||
t.deepEqual(
|
||||
routes.map(route => route.providerId),
|
||||
['byok-aaaaaaaaaaaa-openai-server-key1']
|
||||
);
|
||||
Sinon.assert.calledOnceWithMatch(byok.getProfiles, {
|
||||
workspaceId: 'workspace-1',
|
||||
featureKind: 'workspace_indexing',
|
||||
});
|
||||
});
|
||||
|
||||
test('CopilotProviderFactory should treat embedding preparation as embedding feature by default', async t => {
|
||||
const { factory, byok } = createProviderFactoryWithByokRoutes();
|
||||
|
||||
await factory.prepareEmbeddingRoutes('text-embedding-3-small', 'hello', {
|
||||
workspace: 'workspace-1',
|
||||
});
|
||||
|
||||
t.true(byok.getProfiles.calledOnce);
|
||||
Sinon.assert.calledOnceWithMatch(byok.getProfiles, {
|
||||
workspaceId: 'workspace-1',
|
||||
featureKind: 'embedding',
|
||||
});
|
||||
});
|
||||
|
||||
test('CopilotProviderFactory should resolve BYOK rerank routes before quota-backed routes', async t => {
|
||||
const { factory, byok } = createProviderFactoryWithByokRoutes();
|
||||
|
||||
const preparedRoutes = await factory.prepareRerankRoutes(
|
||||
'gpt-4o-mini',
|
||||
{
|
||||
query: 'programming',
|
||||
candidates: [{ text: 'React is a UI library.' }],
|
||||
},
|
||||
{ workspace: 'workspace-1' }
|
||||
);
|
||||
const resolvedRoutes = await factory.resolveRoutes(
|
||||
{ modelId: 'gpt-4o-mini', outputType: ModelOutputType.Rerank },
|
||||
{},
|
||||
{ workspaceId: 'workspace-1', featureKind: 'rerank' }
|
||||
);
|
||||
|
||||
t.deepEqual(
|
||||
preparedRoutes.map(route => route.providerId),
|
||||
[]
|
||||
);
|
||||
t.deepEqual(
|
||||
resolvedRoutes.map(route => route.providerId),
|
||||
['byok-aaaaaaaaaaaa-openai-server-key1']
|
||||
);
|
||||
Sinon.assert.calledWithMatch(byok.getProfiles, {
|
||||
workspaceId: 'workspace-1',
|
||||
featureKind: 'rerank',
|
||||
});
|
||||
});
|
||||
|
||||
test('CopilotProviderFactory should treat image preparation as image feature by default', async t => {
|
||||
const { factory, byok } = createProviderFactoryWithByokRoutes();
|
||||
|
||||
await factory.prepareImageRoutes(
|
||||
{ modelId: 'gpt-image-1', outputType: ModelOutputType.Image },
|
||||
singleUserPromptMessages('draw a cat'),
|
||||
{ workspace: 'workspace-1' }
|
||||
);
|
||||
|
||||
t.true(byok.getProfiles.calledOnce);
|
||||
Sinon.assert.calledOnceWithMatch(byok.getProfiles, {
|
||||
workspaceId: 'workspace-1',
|
||||
featureKind: 'image',
|
||||
});
|
||||
});
|
||||
|
||||
test('CopilotProviderFactory should omit quota-backed routes when quota is exhausted', async t => {
|
||||
const { factory } = createProviderFactoryWithByokRoutes({ hasQuota: false });
|
||||
|
||||
const routes = await factory.resolveRoutes(
|
||||
{ modelId: 'gpt-5-mini', outputType: ModelOutputType.Text },
|
||||
{},
|
||||
{ userId: 'user-1', workspaceId: 'workspace-1' }
|
||||
);
|
||||
|
||||
t.deepEqual(
|
||||
routes.map(route => route.providerId),
|
||||
['byok-aaaaaaaaaaaa-openai-server-key1']
|
||||
);
|
||||
});
|
||||
|
||||
test('CopilotProviderFactory should raise quota exceeded when only quota-backed routes match', async t => {
|
||||
const { factory } = createProviderFactoryWithByokRoutes({
|
||||
byokProfiles: [],
|
||||
hasQuota: false,
|
||||
});
|
||||
|
||||
await t.throwsAsync(
|
||||
factory.resolveRoutes(
|
||||
{ modelId: 'gpt-5-mini', outputType: ModelOutputType.Text },
|
||||
{},
|
||||
{ userId: 'user-1', workspaceId: 'workspace-1' }
|
||||
),
|
||||
{ instanceOf: CopilotQuotaExceeded }
|
||||
);
|
||||
});
|
||||
|
||||
test('CopilotProviderFactory should not report quota exhausted when quota-backed routes are disabled', async t => {
|
||||
const { factory } = createProviderFactoryWithByokRoutes({
|
||||
byokProfiles: [],
|
||||
hasQuota: true,
|
||||
});
|
||||
|
||||
const routes = await factory.resolveRoutes(
|
||||
{ modelId: 'gpt-5-mini', outputType: ModelOutputType.Text },
|
||||
{},
|
||||
{
|
||||
userId: 'user-1',
|
||||
workspaceId: 'workspace-1',
|
||||
quotaBackedRoutesAllowed: false,
|
||||
}
|
||||
);
|
||||
|
||||
t.deepEqual(routes, []);
|
||||
});
|
||||
|
||||
test('selectModel should reject unknown models without online fallback', t => {
|
||||
const provider = new TestOpenAIProvider();
|
||||
t.is(provider.resolveModel('online-preview'), undefined);
|
||||
@@ -1476,7 +1954,7 @@ test('ProviderDriverSpec should freeze declarative driver shape', t => {
|
||||
});
|
||||
|
||||
test('NativeExecutionEngine should dispatch prepared text routes through native fallback', async t => {
|
||||
const engine = new NativeExecutionEngine();
|
||||
const engine = createNativeExecutionEngine();
|
||||
const registry = buildProviderRegistry({
|
||||
profiles: [
|
||||
{
|
||||
@@ -1575,8 +2053,78 @@ test('NativeExecutionEngine should dispatch prepared text routes through native
|
||||
t.snapshot(summarizePreparedDispatchRoutes(capturedRoutes));
|
||||
});
|
||||
|
||||
test('NativeExecutionEngine should record single BYOK route dispatch failure', async t => {
|
||||
const byok = {
|
||||
recordProviderFailure: Sinon.stub().resolves(),
|
||||
recordUsage: Sinon.stub().resolves(),
|
||||
};
|
||||
const engine = new NativeExecutionEngine(byok as never);
|
||||
const providerId = 'byok-aaaaaaaaaaaa-openai-server-key1';
|
||||
|
||||
const original = (serverNativeModule as any).llmDispatchPrepared;
|
||||
(serverNativeModule as any).llmDispatchPrepared = () => {
|
||||
throw new Error('401 invalid sk-test-primary');
|
||||
};
|
||||
t.teardown(() => {
|
||||
(serverNativeModule as any).llmDispatchPrepared = original;
|
||||
});
|
||||
|
||||
const error = await t.throwsAsync(
|
||||
engine.execute({
|
||||
nativeDispatch: {
|
||||
chat: {
|
||||
routes: [
|
||||
nativeRoute({
|
||||
providerId,
|
||||
authToken: 'primary-key',
|
||||
request: nativeTextRequest('hello'),
|
||||
}),
|
||||
],
|
||||
prepared: {
|
||||
route: preparedRoute({
|
||||
providerId,
|
||||
authToken: 'primary-key',
|
||||
}),
|
||||
request: nativeTextRequest('hello'),
|
||||
tools: {},
|
||||
postprocess: { nodeTextMiddleware: [] },
|
||||
},
|
||||
hasTools: false,
|
||||
},
|
||||
},
|
||||
request: {
|
||||
kind: 'text',
|
||||
cond: { modelId: 'gpt-5-mini' },
|
||||
messages: singleUserPromptMessages('hello'),
|
||||
options: {
|
||||
workspace: 'workspace-1',
|
||||
user: 'user-1',
|
||||
session: 'session-1',
|
||||
featureKind: 'chat',
|
||||
},
|
||||
},
|
||||
routePolicy: { fallbackOrder: [providerId] },
|
||||
runtimePolicy: {},
|
||||
attachmentPolicy: { materializeRemoteAttachments: true },
|
||||
responsePostprocess: { mode: 'text' },
|
||||
hostPersistence: { persistAssistantTurn: true, outputKind: 'text' },
|
||||
hostContext: {
|
||||
currentMessages: singleUserPromptMessages('hello'),
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
t.truthy(error);
|
||||
Sinon.assert.calledOnceWithMatch(byok.recordProviderFailure, {
|
||||
workspaceId: 'workspace-1',
|
||||
providerId,
|
||||
featureKind: 'chat',
|
||||
});
|
||||
Sinon.assert.notCalled(byok.recordUsage);
|
||||
});
|
||||
|
||||
test('NativeExecutionEngine should reject single-route plans when no native route is prepared', async t => {
|
||||
const engine = new NativeExecutionEngine();
|
||||
const engine = createNativeExecutionEngine();
|
||||
|
||||
const error = await t.throwsAsync(
|
||||
engine.execute({
|
||||
@@ -1604,7 +2152,7 @@ test('NativeExecutionEngine should reject single-route plans when no native rout
|
||||
});
|
||||
|
||||
test('NativeExecutionEngine should prefer prepared native fallback dispatch for explicit routes', async t => {
|
||||
const engine = new NativeExecutionEngine();
|
||||
const engine = createNativeExecutionEngine();
|
||||
let capturedRoutes: unknown;
|
||||
let called = false;
|
||||
|
||||
@@ -1683,7 +2231,7 @@ test('NativeExecutionEngine should prefer prepared native fallback dispatch for
|
||||
});
|
||||
|
||||
test('NativeExecutionEngine should stream through prepared native fallback dispatch', async t => {
|
||||
const engine = new NativeExecutionEngine();
|
||||
const engine = createNativeExecutionEngine();
|
||||
let called = false;
|
||||
|
||||
const original = (serverNativeModule as any).llmDispatchPreparedStream;
|
||||
@@ -1912,7 +2460,7 @@ test('ExecutionPlanBuilder should keep single-route tool chat plans on prepared_
|
||||
});
|
||||
|
||||
test('NativeExecutionEngine should route tool-loop chat prepared routes through native dispatch', async t => {
|
||||
const engine = new NativeExecutionEngine();
|
||||
const engine = createNativeExecutionEngine();
|
||||
let capturedRoutes: unknown;
|
||||
let called = false;
|
||||
let toolCallbackCount = 0;
|
||||
@@ -2262,7 +2810,7 @@ test('ExecutionPlanBuilder should build native prepared routes for structured, i
|
||||
});
|
||||
|
||||
test('NativeExecutionEngine should dispatch structured prepared routes through native execution', async t => {
|
||||
const engine = new NativeExecutionEngine();
|
||||
const engine = createNativeExecutionEngine();
|
||||
let capturedRoutes: unknown;
|
||||
let called = false;
|
||||
|
||||
@@ -2350,7 +2898,7 @@ test('NativeExecutionEngine should dispatch structured prepared routes through n
|
||||
});
|
||||
|
||||
test('NativeExecutionEngine should dispatch embedding prepared routes through native execution', async t => {
|
||||
const engine = new NativeExecutionEngine();
|
||||
const engine = createNativeExecutionEngine();
|
||||
let capturedRoutes: unknown;
|
||||
let called = false;
|
||||
|
||||
@@ -2424,7 +2972,7 @@ test('NativeExecutionEngine should dispatch embedding prepared routes through na
|
||||
});
|
||||
|
||||
test('NativeExecutionEngine should dispatch rerank prepared routes through native execution', async t => {
|
||||
const engine = new NativeExecutionEngine();
|
||||
const engine = createNativeExecutionEngine();
|
||||
let capturedRoutes: unknown;
|
||||
let called = false;
|
||||
|
||||
@@ -2507,7 +3055,7 @@ test('NativeExecutionEngine should dispatch rerank prepared routes through nativ
|
||||
});
|
||||
|
||||
test('NativeExecutionEngine should dispatch image plans through prepared native routes', async t => {
|
||||
const engine = new NativeExecutionEngine();
|
||||
const engine = createNativeExecutionEngine();
|
||||
let capturedRoutes: unknown;
|
||||
const original = (serverNativeModule as any).llmImageDispatchPrepared;
|
||||
(serverNativeModule as any).llmImageDispatchPrepared = (
|
||||
@@ -2587,8 +3135,90 @@ test('NativeExecutionEngine should dispatch image plans through prepared native
|
||||
t.snapshot(summarizePreparedDispatchRoutes(capturedRoutes));
|
||||
});
|
||||
|
||||
test('NativeExecutionEngine should record zero-token BYOK image usage without provider usage', async t => {
|
||||
const byok = {
|
||||
recordUsage: Sinon.stub().resolves(),
|
||||
};
|
||||
const engine = new NativeExecutionEngine(byok as never);
|
||||
const providerId = 'byok-aaaaaaaaaaaa-fal-server-key1';
|
||||
|
||||
const original = (serverNativeModule as any).llmImageDispatchPrepared;
|
||||
(serverNativeModule as any).llmImageDispatchPrepared = () => {
|
||||
return JSON.stringify({
|
||||
provider_id: providerId,
|
||||
response: {
|
||||
images: [
|
||||
{
|
||||
url: 'https://cdn.example.com/image.png',
|
||||
media_type: 'image/png',
|
||||
},
|
||||
],
|
||||
},
|
||||
});
|
||||
};
|
||||
t.teardown(() => {
|
||||
(serverNativeModule as any).llmImageDispatchPrepared = original;
|
||||
});
|
||||
|
||||
const request = nativeImageRequest('draw a cat');
|
||||
const imageArtifacts = await collectAsync(
|
||||
engine.executeImageArtifacts({
|
||||
nativeDispatch: {
|
||||
image: {
|
||||
routes: [
|
||||
nativeRoute({
|
||||
providerId,
|
||||
authToken: 'image-key',
|
||||
protocol: 'fal_image',
|
||||
model: 'fal-ai/fast-sdxl',
|
||||
request,
|
||||
}),
|
||||
],
|
||||
prepared: {
|
||||
route: preparedRoute({
|
||||
providerId,
|
||||
authToken: 'image-key',
|
||||
protocol: 'fal_image',
|
||||
model: 'fal-ai/fast-sdxl',
|
||||
}),
|
||||
request,
|
||||
},
|
||||
},
|
||||
},
|
||||
request: {
|
||||
kind: 'image',
|
||||
cond: { modelId: 'fal-ai/fast-sdxl' },
|
||||
messages: singleUserPromptMessages('draw a cat'),
|
||||
options: {
|
||||
workspace: 'workspace-1',
|
||||
user: 'user-1',
|
||||
session: 'session-1',
|
||||
featureKind: 'image',
|
||||
},
|
||||
},
|
||||
routePolicy: { fallbackOrder: [providerId] },
|
||||
runtimePolicy: {},
|
||||
attachmentPolicy: { materializeRemoteAttachments: true },
|
||||
responsePostprocess: { mode: 'image' },
|
||||
hostPersistence: { persistAssistantTurn: true, outputKind: 'image' },
|
||||
hostContext: {},
|
||||
})
|
||||
);
|
||||
|
||||
t.is(imageArtifacts.length, 1);
|
||||
Sinon.assert.calledOnceWithMatch(byok.recordUsage, {
|
||||
workspaceId: 'workspace-1',
|
||||
userId: 'user-1',
|
||||
sessionId: 'session-1',
|
||||
featureKind: 'image',
|
||||
providerId,
|
||||
model: 'fal-ai/fast-sdxl',
|
||||
usage: undefined,
|
||||
});
|
||||
});
|
||||
|
||||
test('NativeExecutionEngine should reject image plans without native dispatch', async t => {
|
||||
const engine = new NativeExecutionEngine();
|
||||
const engine = createNativeExecutionEngine();
|
||||
|
||||
await t.throwsAsync(
|
||||
collectAsync(
|
||||
|
||||
@@ -534,6 +534,58 @@ test('doc_semantic_search should return empty array when nothing matches', async
|
||||
t.deepEqual(result, []);
|
||||
});
|
||||
|
||||
test('doc_semantic_search should pass BYOK route context into embedding matches', async t => {
|
||||
const ac = {
|
||||
user: () => ({
|
||||
workspace: () => ({
|
||||
can: async () => true,
|
||||
docs: async () => [],
|
||||
}),
|
||||
}),
|
||||
} as unknown as AccessController;
|
||||
|
||||
const models = {
|
||||
workspace: {
|
||||
get: async () => ({ id: 'workspace-1' }),
|
||||
},
|
||||
} as unknown as Models;
|
||||
|
||||
let workspaceRouteContext: unknown;
|
||||
let sessionRouteContext: unknown;
|
||||
const contextService = {
|
||||
matchWorkspaceAll: async (...args: unknown[]) => {
|
||||
workspaceRouteContext = args[7];
|
||||
return [];
|
||||
},
|
||||
getBySessionId: async () => ({
|
||||
matchFiles: async (...args: unknown[]) => {
|
||||
sessionRouteContext = args[5];
|
||||
return [];
|
||||
},
|
||||
}),
|
||||
} as unknown as Parameters<typeof buildDocSearchGetter>[1];
|
||||
|
||||
const semanticTool = createDocSemanticSearchTool(
|
||||
buildDocSearchGetter(ac, contextService, 'session-1', models).bind(null, {
|
||||
user: 'user-1',
|
||||
workspace: 'workspace-1',
|
||||
byokLeaseId: 'lease-1',
|
||||
})
|
||||
);
|
||||
|
||||
const result = await semanticTool.execute?.({ query: 'hello' }, {});
|
||||
|
||||
t.deepEqual(result, []);
|
||||
t.deepEqual(workspaceRouteContext, {
|
||||
userId: 'user-1',
|
||||
byokLeaseId: 'lease-1',
|
||||
});
|
||||
t.deepEqual(sessionRouteContext, {
|
||||
userId: 'user-1',
|
||||
byokLeaseId: 'lease-1',
|
||||
});
|
||||
});
|
||||
|
||||
test('blob_read should return explicit error when attachment context is missing', async t => {
|
||||
const ac = {
|
||||
user: () => ({
|
||||
|
||||
@@ -215,7 +215,7 @@ test('settleTask checks copilot quota before unlocking ready task', async t => {
|
||||
status: 'settled',
|
||||
protectedResult: payload,
|
||||
});
|
||||
const checkQuota = Sinon.stub().rejects(new Error('quota exceeded'));
|
||||
const assertQuotaOrByok = Sinon.stub().rejects(new Error('quota exceeded'));
|
||||
const service = new CopilotTranscriptionService(
|
||||
{
|
||||
copilotTranscriptTask: {
|
||||
@@ -232,14 +232,18 @@ test('settleTask checks copilot quota before unlocking ready task', async t => {
|
||||
{} as never,
|
||||
{} as never,
|
||||
{} as never,
|
||||
{ checkQuota } as never
|
||||
{ assertQuotaOrByok } as never
|
||||
);
|
||||
|
||||
await t.throwsAsync(
|
||||
() => service.settleTask('user-1', 'workspace-1', 'task-1'),
|
||||
{ message: /quota exceeded/ }
|
||||
);
|
||||
Sinon.assert.calledOnceWithExactly(checkQuota, 'user-1');
|
||||
Sinon.assert.calledOnceWithMatch(assertQuotaOrByok, {
|
||||
userId: 'user-1',
|
||||
workspaceId: 'workspace-1',
|
||||
featureKind: 'transcript',
|
||||
});
|
||||
Sinon.assert.notCalled(settle);
|
||||
});
|
||||
|
||||
@@ -341,6 +345,48 @@ test('retryTask reuses failed task and queues a new action attempt', async t =>
|
||||
Sinon.assert.calledOnceWithExactly(markRunning, 'task-1');
|
||||
});
|
||||
|
||||
test('retryTask prechecks quota or BYOK before queueing provider work', async t => {
|
||||
const add = Sinon.stub().resolves(undefined);
|
||||
const markRunning = Sinon.stub().resolves({ id: 'task-1' });
|
||||
const assertQuotaOrByok = Sinon.stub().rejects(new Error('quota exceeded'));
|
||||
const payload = TranscriptPayloadSchema.parse({
|
||||
normalizedTranscript: '00:00:05 A: Kickoff',
|
||||
});
|
||||
const service = new CopilotTranscriptionService(
|
||||
{
|
||||
copilotTranscriptTask: {
|
||||
getWithUser: Sinon.stub().resolves({
|
||||
id: 'task-1',
|
||||
status: 'failed',
|
||||
strategy: 'gemini',
|
||||
protectedResult: payload,
|
||||
}),
|
||||
markRunning,
|
||||
},
|
||||
} as never,
|
||||
{ add } as never,
|
||||
{} as never,
|
||||
{
|
||||
resolveTranscriptionModel: Sinon.stub().resolves('gemini-2.5-flash'),
|
||||
} as never,
|
||||
{} as never,
|
||||
{} as never,
|
||||
{ assertQuotaOrByok } as never
|
||||
);
|
||||
|
||||
await t.throwsAsync(
|
||||
() => service.retryTask('user-1', 'workspace-1', 'task-1'),
|
||||
{ message: /quota exceeded/ }
|
||||
);
|
||||
Sinon.assert.calledOnceWithMatch(assertQuotaOrByok, {
|
||||
userId: 'user-1',
|
||||
workspaceId: 'workspace-1',
|
||||
featureKind: 'transcript',
|
||||
});
|
||||
Sinon.assert.notCalled(add);
|
||||
Sinon.assert.notCalled(markRunning);
|
||||
});
|
||||
|
||||
for (const status of ['ready', 'settled']) {
|
||||
test(`submitTask allows a new task for the same blob after ${status} task`, async t => {
|
||||
const createdTasks: unknown[] = [];
|
||||
@@ -390,6 +436,37 @@ for (const status of ['ready', 'settled']) {
|
||||
});
|
||||
}
|
||||
|
||||
test('submitTask prechecks quota or BYOK before persisting uploads', async t => {
|
||||
const assertQuotaOrByok = Sinon.stub().rejects(new Error('quota exceeded'));
|
||||
const resolveTranscriptionModel = Sinon.stub().resolves('gemini-2.5-flash');
|
||||
const service = new CopilotTranscriptionService(
|
||||
{
|
||||
copilotTranscriptTask: {
|
||||
getWithUser: Sinon.stub().resolves(null),
|
||||
},
|
||||
} as never,
|
||||
{} as never,
|
||||
{} as never,
|
||||
{
|
||||
resolveTranscriptionModel,
|
||||
} as never,
|
||||
{} as never,
|
||||
{} as never,
|
||||
{ assertQuotaOrByok } as never
|
||||
);
|
||||
|
||||
await t.throwsAsync(
|
||||
() => service.submitTask('user-1', 'workspace-1', 'blob-1', []),
|
||||
{ message: /quota exceeded/ }
|
||||
);
|
||||
Sinon.assert.calledOnceWithMatch(assertQuotaOrByok, {
|
||||
userId: 'user-1',
|
||||
workspaceId: 'workspace-1',
|
||||
featureKind: 'transcript',
|
||||
});
|
||||
Sinon.assert.notCalled(resolveTranscriptionModel);
|
||||
});
|
||||
|
||||
test('submitTask rejects unavailable transcript strategy', async t => {
|
||||
const service = new CopilotTranscriptionService(
|
||||
{
|
||||
|
||||
@@ -1145,6 +1145,110 @@ test('should count action runs without double-counting legacy action sessions',
|
||||
t.truthy(legacyAction.sessionId);
|
||||
});
|
||||
|
||||
test('should exclude BYOK provider usage from copilot quota cost', async t => {
|
||||
const { copilotSession, db, models } = t.context;
|
||||
await createTestPrompts(copilotSession, db);
|
||||
|
||||
const regular = await createTestSession(t);
|
||||
const firstMessage = await copilotSession.appendMessage({
|
||||
sessionId: regular.sessionId,
|
||||
userId: user.id,
|
||||
prompt: { model: 'gpt-5-mini' },
|
||||
message: {
|
||||
role: 'user',
|
||||
content: 'regular message',
|
||||
createdAt: new Date(),
|
||||
},
|
||||
});
|
||||
const secondMessage = await copilotSession.appendMessage({
|
||||
sessionId: regular.sessionId,
|
||||
userId: user.id,
|
||||
prompt: { model: 'gpt-5-mini' },
|
||||
message: {
|
||||
role: 'user',
|
||||
content: 'second BYOK message',
|
||||
createdAt: new Date(),
|
||||
},
|
||||
});
|
||||
await copilotSession.appendMessage({
|
||||
sessionId: regular.sessionId,
|
||||
userId: user.id,
|
||||
prompt: { model: 'gpt-5-mini' },
|
||||
message: {
|
||||
role: 'user',
|
||||
content: 'quota-backed message',
|
||||
createdAt: new Date(),
|
||||
},
|
||||
});
|
||||
const failedRun = await models.copilotActionRun.create({
|
||||
userId: user.id,
|
||||
workspaceId: workspace.id,
|
||||
actionId: 'mindmap.generate',
|
||||
actionVersion: 'v1',
|
||||
});
|
||||
await models.copilotActionRun.complete(failedRun.id, {
|
||||
status: 'failed',
|
||||
errorCode: 'test_failed',
|
||||
});
|
||||
const pendingTranscriptTask = await models.copilotTranscriptTask.create({
|
||||
userId: user.id,
|
||||
workspaceId: workspace.id,
|
||||
blobId: 'pending-audio',
|
||||
strategy: 'gemini',
|
||||
recipeId: 'transcript.audio.gemini',
|
||||
recipeVersion: 'v1',
|
||||
});
|
||||
await models.copilotUsage.create({
|
||||
workspaceId: workspace.id,
|
||||
userId: user.id,
|
||||
provider: 'openai',
|
||||
providerSource: 'byok_server',
|
||||
featureKind: 'chat',
|
||||
billingUnitId: firstMessage.id,
|
||||
});
|
||||
await models.copilotUsage.create({
|
||||
workspaceId: workspace.id,
|
||||
userId: user.id,
|
||||
provider: 'openai',
|
||||
providerSource: 'byok_server',
|
||||
featureKind: 'chat',
|
||||
billingUnitId: firstMessage.id,
|
||||
});
|
||||
await models.copilotUsage.create({
|
||||
workspaceId: workspace.id,
|
||||
userId: user.id,
|
||||
provider: 'fal',
|
||||
providerSource: 'byok_server',
|
||||
featureKind: 'image',
|
||||
billingUnitId: secondMessage.id,
|
||||
});
|
||||
await models.copilotUsage.create({
|
||||
workspaceId: workspace.id,
|
||||
userId: user.id,
|
||||
provider: 'fal',
|
||||
providerSource: 'byok_server',
|
||||
featureKind: 'image',
|
||||
});
|
||||
await models.copilotUsage.create({
|
||||
workspaceId: workspace.id,
|
||||
userId: user.id,
|
||||
provider: 'openai',
|
||||
providerSource: 'byok_server',
|
||||
featureKind: 'action',
|
||||
billingUnitId: failedRun.id,
|
||||
});
|
||||
await models.copilotUsage.create({
|
||||
workspaceId: workspace.id,
|
||||
userId: user.id,
|
||||
provider: 'gemini',
|
||||
providerSource: 'byok_server',
|
||||
featureKind: 'transcript',
|
||||
billingUnitId: pendingTranscriptTask.id,
|
||||
});
|
||||
|
||||
t.is(await copilotSession.countUserMessages(user.id), 1);
|
||||
});
|
||||
|
||||
test('should get sessions for title generation correctly', async t => {
|
||||
const { copilotSession, db } = t.context;
|
||||
await createTestPrompts(copilotSession, db);
|
||||
|
||||
@@ -5,8 +5,9 @@ import ava, { ExecutionContext, TestFn } from 'ava';
|
||||
import Sinon from 'sinon';
|
||||
import { Doc as YDoc } from 'yjs';
|
||||
|
||||
import { MockEventBus } from '../../../__tests__/mocks';
|
||||
import { createTestingApp, type TestingApp } from '../../../__tests__/utils';
|
||||
import { ConfigFactory } from '../../../base';
|
||||
import { ConfigFactory, EventBus } from '../../../base';
|
||||
import { Flavor } from '../../../env';
|
||||
import { Models } from '../../../models';
|
||||
import { DocReader, PgWorkspaceDocStorageAdapter } from '../../doc';
|
||||
@@ -16,6 +17,7 @@ interface Context {
|
||||
app: TestingApp;
|
||||
adapter: PgWorkspaceDocStorageAdapter;
|
||||
docReader: DocReader;
|
||||
recordDocView: Sinon.SinonStub;
|
||||
}
|
||||
|
||||
const test = ava as TestFn<Context>;
|
||||
@@ -23,7 +25,9 @@ const test = ava as TestFn<Context>;
|
||||
test.before(async t => {
|
||||
// @ts-expect-error testing
|
||||
env.FLAVOR = Flavor.Renderer;
|
||||
const app = await createTestingApp();
|
||||
const app = await createTestingApp({
|
||||
tapModule: m => m.overrideProvider(EventBus).useClass(MockEventBus),
|
||||
});
|
||||
|
||||
t.context.models = app.get(Models);
|
||||
t.context.adapter = app.get(PgWorkspaceDocStorageAdapter);
|
||||
@@ -45,6 +49,14 @@ test.beforeEach(async t => {
|
||||
email: 'test@affine.pro',
|
||||
});
|
||||
workspace = await t.context.models.workspace.create(user.id);
|
||||
t.context.recordDocView = Sinon.stub(
|
||||
t.context.models.workspaceAnalytics,
|
||||
'recordDocView'
|
||||
).resolves();
|
||||
});
|
||||
|
||||
test.afterEach.always(t => {
|
||||
t.context.recordDocView?.restore();
|
||||
});
|
||||
|
||||
test.after.always(async t => {
|
||||
@@ -88,10 +100,7 @@ test('should record page view when rendering shared page', async t => {
|
||||
title: 'analytics-doc',
|
||||
summary: 'summary',
|
||||
});
|
||||
const record = Sinon.stub(
|
||||
models.workspaceAnalytics,
|
||||
'recordDocView'
|
||||
).resolves();
|
||||
const record = t.context.recordDocView;
|
||||
|
||||
await app.GET(`/workspace/${workspace.id}/${docId}`).expect(200);
|
||||
|
||||
@@ -103,7 +112,6 @@ test('should record page view when rendering shared page', async t => {
|
||||
});
|
||||
|
||||
docContent.restore();
|
||||
record.restore();
|
||||
});
|
||||
|
||||
const policyCases: Array<{
|
||||
@@ -146,10 +154,7 @@ const policyCases: Array<{
|
||||
unknownBlocks: [],
|
||||
}),
|
||||
docContent: Sinon.stub(docReader, 'getDocContent'),
|
||||
record: Sinon.stub(
|
||||
models.workspaceAnalytics,
|
||||
'recordDocView'
|
||||
).resolves(),
|
||||
record: models.workspaceAnalytics.recordDocView as Sinon.SinonStub,
|
||||
};
|
||||
},
|
||||
request: (app, docId) =>
|
||||
|
||||
139
packages/backend/server/src/models/copilot-byok.ts
Normal file
139
packages/backend/server/src/models/copilot-byok.ts
Normal file
@@ -0,0 +1,139 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { Transactional } from '@nestjs-cls/transactional';
|
||||
|
||||
import { BaseModel } from './base';
|
||||
|
||||
export type UpsertAiWorkspaceByokConfigInput = {
|
||||
id?: string | null;
|
||||
workspaceId: string;
|
||||
provider: string;
|
||||
name: string;
|
||||
description: string | null;
|
||||
encryptedApiKey?: string;
|
||||
endpoint: string | null;
|
||||
sortOrder: number;
|
||||
enabled: boolean;
|
||||
userId?: string;
|
||||
};
|
||||
|
||||
@Injectable()
|
||||
export class CopilotWorkspaceByokConfigModel extends BaseModel {
|
||||
async list(workspaceId: string) {
|
||||
return await this.db.aiWorkspaceByokConfig.findMany({
|
||||
where: { workspaceId },
|
||||
orderBy: [{ sortOrder: 'asc' }, { createdAt: 'asc' }],
|
||||
});
|
||||
}
|
||||
|
||||
async listEnabled(workspaceId: string) {
|
||||
return await this.db.aiWorkspaceByokConfig.findMany({
|
||||
where: { workspaceId, enabled: true },
|
||||
orderBy: [{ sortOrder: 'asc' }, { createdAt: 'asc' }],
|
||||
});
|
||||
}
|
||||
|
||||
async get(id: string) {
|
||||
return await this.db.aiWorkspaceByokConfig.findUnique({
|
||||
where: { id },
|
||||
});
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async upsert(input: UpsertAiWorkspaceByokConfigInput) {
|
||||
const data = {
|
||||
provider: input.provider,
|
||||
name: input.name,
|
||||
description: input.description,
|
||||
endpoint: input.endpoint,
|
||||
sortOrder: input.sortOrder,
|
||||
enabled: input.enabled,
|
||||
updatedBy: input.userId,
|
||||
...(input.encryptedApiKey
|
||||
? {
|
||||
encryptedApiKey: input.encryptedApiKey,
|
||||
lastValidatedAt: new Date(),
|
||||
lastValidationError: null,
|
||||
disabledReason: null,
|
||||
lastError: null,
|
||||
lastErrorAt: null,
|
||||
}
|
||||
: {}),
|
||||
};
|
||||
|
||||
return input.id
|
||||
? await this.db.aiWorkspaceByokConfig.update({
|
||||
where: { id: input.id, workspaceId: input.workspaceId },
|
||||
data,
|
||||
})
|
||||
: await this.db.aiWorkspaceByokConfig.create({
|
||||
data: {
|
||||
...data,
|
||||
encryptedApiKey: input.encryptedApiKey ?? '',
|
||||
workspaceId: input.workspaceId,
|
||||
createdBy: input.userId,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async reorder(workspaceId: string, ids: string[], userId?: string) {
|
||||
await Promise.all(
|
||||
ids.map((id, sortOrder) =>
|
||||
this.db.aiWorkspaceByokConfig.update({
|
||||
where: { id, workspaceId },
|
||||
data: { sortOrder, updatedBy: userId },
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async delete(workspaceId: string, id: string) {
|
||||
await this.db.aiWorkspaceByokConfig.delete({ where: { id, workspaceId } });
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async clear(workspaceId: string, provider?: string | null) {
|
||||
await this.db.aiWorkspaceByokConfig.deleteMany({
|
||||
where: { workspaceId, ...(provider ? { provider } : {}) },
|
||||
});
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async markValidated(workspaceId: string, id: string, userId?: string) {
|
||||
await this.db.aiWorkspaceByokConfig.update({
|
||||
where: { id, workspaceId },
|
||||
data: {
|
||||
enabled: true,
|
||||
disabledReason: null,
|
||||
lastValidatedAt: new Date(),
|
||||
lastValidationError: null,
|
||||
lastError: null,
|
||||
lastErrorAt: null,
|
||||
updatedBy: userId,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async markFailure(workspaceId: string, id: string, message: string) {
|
||||
await this.db.aiWorkspaceByokConfig.update({
|
||||
where: { id, workspaceId },
|
||||
data: {
|
||||
enabled: false,
|
||||
disabledReason: 'recent_failure',
|
||||
lastValidationError: message,
|
||||
lastError: message,
|
||||
lastErrorAt: new Date(),
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async touchUsed(workspaceId: string, id: string) {
|
||||
await this.db.aiWorkspaceByokConfig.updateMany({
|
||||
where: { id, workspaceId },
|
||||
data: { lastUsedAt: new Date() },
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1005,20 +1005,26 @@ export class CopilotSessionModel extends BaseModel {
|
||||
.filter(({ promptAction }) => !promptAction)
|
||||
.map(({ messageCost }) => messageCost)
|
||||
.reduce((prev, cost) => prev + cost, 0);
|
||||
const [actionRunCost, legacyActionSessionCost, transcriptSettlementCost] =
|
||||
await Promise.all([
|
||||
const [
|
||||
actionRunCost,
|
||||
legacyActionSessionCost,
|
||||
transcriptSettlementCost,
|
||||
byokQuotaExemptCost,
|
||||
] = await Promise.all([
|
||||
this.models.copilotActionRun.countSucceededByUser(userId),
|
||||
this.models.copilotActionRun.countLegacyPromptActionSessionsWithoutRun(
|
||||
userId
|
||||
),
|
||||
this.models.copilotTranscriptTask.countSettledByUser(userId),
|
||||
this.models.copilotUsage.countQuotaExemptByokUsage(userId),
|
||||
]);
|
||||
return (
|
||||
const quotaBackedCost =
|
||||
regularMessageCost +
|
||||
actionRunCost +
|
||||
legacyActionSessionCost +
|
||||
transcriptSettlementCost
|
||||
);
|
||||
transcriptSettlementCost -
|
||||
byokQuotaExemptCost;
|
||||
return Math.max(0, quotaBackedCost);
|
||||
}
|
||||
|
||||
async cleanupEmptySessions(earlyThen: Date) {
|
||||
|
||||
142
packages/backend/server/src/models/copilot-usage.ts
Normal file
142
packages/backend/server/src/models/copilot-usage.ts
Normal file
@@ -0,0 +1,142 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { Transactional } from '@nestjs-cls/transactional';
|
||||
import { Prisma } from '@prisma/client';
|
||||
|
||||
import { BaseModel } from './base';
|
||||
|
||||
type CreateAiUsageEventInput = {
|
||||
workspaceId: string;
|
||||
userId?: string;
|
||||
provider: string;
|
||||
providerSource: string;
|
||||
featureKind: string;
|
||||
model?: string | null;
|
||||
sessionId?: string;
|
||||
taskId?: string;
|
||||
actionId?: string;
|
||||
billingUnitId?: string;
|
||||
promptTokens?: number;
|
||||
completionTokens?: number;
|
||||
totalTokens?: number;
|
||||
cachedTokens?: number;
|
||||
};
|
||||
|
||||
type UsageAggregateRow = {
|
||||
date: string;
|
||||
featureKind: string;
|
||||
totalTokens: number | bigint | null;
|
||||
};
|
||||
|
||||
type CountRow = {
|
||||
count: number | bigint;
|
||||
};
|
||||
|
||||
const BYOK_PROVIDER_SOURCES = ['byok_server', 'byok_local'];
|
||||
const QUOTA_EXEMPT_BYOK_FEATURES = ['chat', 'action', 'image', 'transcript'];
|
||||
|
||||
@Injectable()
|
||||
export class CopilotUsageModel extends BaseModel {
|
||||
@Transactional()
|
||||
async create(input: CreateAiUsageEventInput) {
|
||||
await this.db.aiUsageEvent.create({
|
||||
data: {
|
||||
workspaceId: input.workspaceId,
|
||||
userId: input.userId,
|
||||
provider: input.provider,
|
||||
providerSource: input.providerSource,
|
||||
featureKind: input.featureKind,
|
||||
model: input.model ?? null,
|
||||
sessionId: input.sessionId,
|
||||
taskId: input.taskId,
|
||||
actionId: input.actionId,
|
||||
billingUnitId: input.billingUnitId,
|
||||
promptTokens: input.promptTokens ?? 0,
|
||||
completionTokens: input.completionTokens ?? 0,
|
||||
totalTokens: input.totalTokens ?? 0,
|
||||
cachedTokens: input.cachedTokens ?? 0,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
async countQuotaExemptByokUsage(userId: string) {
|
||||
const rows = await this.db.$queryRaw<CountRow[]>(Prisma.sql`
|
||||
WITH "byok_usage" AS (
|
||||
SELECT "billing_unit_id", "feature_kind"
|
||||
FROM "ai_usage_events"
|
||||
WHERE "user_id" = ${userId}
|
||||
AND "provider_source" IN (${Prisma.join(BYOK_PROVIDER_SOURCES)})
|
||||
AND "feature_kind" IN (${Prisma.join(QUOTA_EXEMPT_BYOK_FEATURES)})
|
||||
AND "billing_unit_id" IS NOT NULL
|
||||
),
|
||||
"message_units" AS (
|
||||
SELECT DISTINCT "usage"."billing_unit_id"
|
||||
FROM "byok_usage" AS "usage"
|
||||
JOIN "ai_sessions_messages" AS "message"
|
||||
ON "message"."id" = "usage"."billing_unit_id"
|
||||
JOIN "ai_sessions_metadata" AS "session"
|
||||
ON "session"."id" = "message"."session_id"
|
||||
WHERE "usage"."feature_kind" IN ('chat', 'action', 'image')
|
||||
AND "message"."role" = 'user'
|
||||
AND "session"."user_id" = ${userId}
|
||||
AND ("session"."prompt_action" IS NULL OR "session"."prompt_action" = '')
|
||||
),
|
||||
"action_units" AS (
|
||||
SELECT DISTINCT "usage"."billing_unit_id"
|
||||
FROM "byok_usage" AS "usage"
|
||||
JOIN "ai_action_runs" AS "run"
|
||||
ON "run"."id" = "usage"."billing_unit_id"
|
||||
WHERE "usage"."feature_kind" IN ('action', 'image')
|
||||
AND "run"."user_id" = ${userId}
|
||||
AND "run"."status" = 'succeeded'
|
||||
AND "run"."action_id" NOT LIKE 'transcript.audio.%'
|
||||
),
|
||||
"transcript_units" AS (
|
||||
SELECT DISTINCT "usage"."billing_unit_id"
|
||||
FROM "byok_usage" AS "usage"
|
||||
JOIN "ai_transcript_tasks" AS "task"
|
||||
ON "task"."id" = "usage"."billing_unit_id"
|
||||
WHERE "usage"."feature_kind" = 'transcript'
|
||||
AND "task"."user_id" = ${userId}
|
||||
AND "task"."status" = 'settled'
|
||||
)
|
||||
SELECT (
|
||||
(SELECT COUNT(*) FROM "message_units") +
|
||||
(SELECT COUNT(*) FROM "action_units") +
|
||||
(SELECT COUNT(*) FROM "transcript_units")
|
||||
) AS "count"
|
||||
`);
|
||||
const count = rows[0]?.count ?? 0;
|
||||
return typeof count === 'bigint' ? Number(count) : count;
|
||||
}
|
||||
|
||||
async aggregateByDay(input: {
|
||||
workspaceId: string;
|
||||
from: Date;
|
||||
to: Date;
|
||||
providerSources: string[];
|
||||
}) {
|
||||
if (!input.providerSources.length) return [];
|
||||
|
||||
const rows = await this.db.$queryRaw<UsageAggregateRow[]>(Prisma.sql`
|
||||
SELECT
|
||||
to_char(date_trunc('day', "created_at" AT TIME ZONE 'UTC'), 'YYYY-MM-DD') AS "date",
|
||||
"feature_kind" AS "featureKind",
|
||||
COALESCE(SUM("total_tokens"), 0)::bigint AS "totalTokens"
|
||||
FROM "ai_usage_events"
|
||||
WHERE "workspace_id" = ${input.workspaceId}
|
||||
AND "provider_source" IN (${Prisma.join(input.providerSources)})
|
||||
AND "created_at" >= ${input.from}
|
||||
AND "created_at" < ${input.to}
|
||||
GROUP BY 1, 2
|
||||
ORDER BY 1 ASC, 2 ASC
|
||||
`);
|
||||
|
||||
return rows.map(row => {
|
||||
return {
|
||||
date: new Date(`${row.date}T00:00:00.000Z`),
|
||||
featureKind: row.featureKind,
|
||||
totalTokens: Number(row.totalTokens ?? 0),
|
||||
};
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -17,10 +17,12 @@ import { CommentModel } from './comment';
|
||||
import { CommentAttachmentModel } from './comment-attachment';
|
||||
import { AppConfigModel } from './config';
|
||||
import { CopilotActionRunModel } from './copilot-action-run';
|
||||
import { CopilotWorkspaceByokConfigModel } from './copilot-byok';
|
||||
import { CopilotContextModel } from './copilot-context';
|
||||
import { CopilotJobModel } from './copilot-job';
|
||||
import { CopilotSessionModel } from './copilot-session';
|
||||
import { CopilotTranscriptTaskModel } from './copilot-transcript-task';
|
||||
import { CopilotUsageModel } from './copilot-usage';
|
||||
import { CopilotWorkspaceConfigModel } from './copilot-workspace';
|
||||
import { DocModel } from './doc';
|
||||
import { DocUserModel } from './doc-user';
|
||||
@@ -58,10 +60,12 @@ const MODELS = {
|
||||
notification: NotificationModel,
|
||||
userSettings: UserSettingsModel,
|
||||
copilotSession: CopilotSessionModel,
|
||||
copilotUsage: CopilotUsageModel,
|
||||
copilotTranscriptTask: CopilotTranscriptTaskModel,
|
||||
copilotActionRun: CopilotActionRunModel,
|
||||
copilotContext: CopilotContextModel,
|
||||
copilotWorkspace: CopilotWorkspaceConfigModel,
|
||||
copilotWorkspaceByokConfig: CopilotWorkspaceByokConfigModel,
|
||||
copilotJob: CopilotJobModel,
|
||||
appConfig: AppConfigModel,
|
||||
comment: CommentModel,
|
||||
@@ -133,10 +137,12 @@ export * from './calendar-subscription';
|
||||
export * from './comment';
|
||||
export * from './comment-attachment';
|
||||
export * from './common';
|
||||
export * from './copilot-byok';
|
||||
export * from './copilot-context';
|
||||
export * from './copilot-job';
|
||||
export * from './copilot-session';
|
||||
export * from './copilot-transcript-task';
|
||||
export * from './copilot-usage';
|
||||
export * from './copilot-workspace';
|
||||
export * from './doc';
|
||||
export * from './doc-user';
|
||||
|
||||
@@ -458,6 +458,7 @@ type LlmRerankResponse = {
|
||||
|
||||
export type LlmToolLoopStreamEvent =
|
||||
| { type: 'message_start'; id?: string; model?: string }
|
||||
| { type: 'provider_selected'; provider_id: string }
|
||||
| { type: 'text_delta'; text: string }
|
||||
| { type: 'reasoning_delta'; text: string }
|
||||
| {
|
||||
@@ -537,7 +538,14 @@ function parseLlmEventJson(eventJson: string): LlmStreamEvent {
|
||||
function parseLlmToolLoopStreamEvent(
|
||||
eventJson: string
|
||||
): LlmToolLoopStreamEvent {
|
||||
return parseToolLoopStreamEvent(parseLlmEventJson(eventJson));
|
||||
const event = parseLlmEventJson(eventJson);
|
||||
if (
|
||||
event.type === 'provider_selected' &&
|
||||
typeof event.provider_id === 'string'
|
||||
) {
|
||||
return event;
|
||||
}
|
||||
return parseToolLoopStreamEvent(event);
|
||||
}
|
||||
|
||||
export function llmMatchModelCapabilities(
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
import type { ByokFeatureKind } from '../byok/types';
|
||||
|
||||
export type ByokSourceCoverage = {
|
||||
local: boolean;
|
||||
server: boolean;
|
||||
};
|
||||
|
||||
export type CopilotFeatureAccessRule = ByokSourceCoverage & {
|
||||
quotaMetered: boolean;
|
||||
};
|
||||
|
||||
const DEFAULT_BYOK_COVERAGE: ByokSourceCoverage = {
|
||||
local: true,
|
||||
server: true,
|
||||
};
|
||||
|
||||
const DEFAULT_FEATURE_ACCESS: CopilotFeatureAccessRule = {
|
||||
...DEFAULT_BYOK_COVERAGE,
|
||||
quotaMetered: true,
|
||||
};
|
||||
|
||||
const COPILOT_FEATURE_ACCESS: Partial<
|
||||
Record<ByokFeatureKind, CopilotFeatureAccessRule>
|
||||
> = {
|
||||
transcript: { local: false, server: true, quotaMetered: true },
|
||||
embedding: { local: false, server: true, quotaMetered: false },
|
||||
workspace_indexing: { local: false, server: true, quotaMetered: false },
|
||||
rerank: { local: false, server: true, quotaMetered: false },
|
||||
};
|
||||
|
||||
export function getByokSourceCoverage(
|
||||
featureKind?: ByokFeatureKind
|
||||
): ByokSourceCoverage {
|
||||
const access = getCopilotFeatureAccess(featureKind);
|
||||
return { local: access.local, server: access.server };
|
||||
}
|
||||
|
||||
export function getCopilotFeatureAccess(
|
||||
featureKind?: ByokFeatureKind
|
||||
): CopilotFeatureAccessRule {
|
||||
return featureKind
|
||||
? (COPILOT_FEATURE_ACCESS[featureKind] ?? DEFAULT_FEATURE_ACCESS)
|
||||
: DEFAULT_FEATURE_ACCESS;
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
export * from './feature-coverage';
|
||||
export * from './policy';
|
||||
106
packages/backend/server/src/plugins/copilot/access/policy.ts
Normal file
106
packages/backend/server/src/plugins/copilot/access/policy.ts
Normal file
@@ -0,0 +1,106 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
|
||||
import { CopilotQuotaExceeded } from '../../../base';
|
||||
import { ByokService } from '../byok/service';
|
||||
import type { ByokFeatureKind } from '../byok/types';
|
||||
import type { CopilotProviderProfile } from '../config';
|
||||
import { ConversationPolicy } from '../conversation/policy';
|
||||
import {
|
||||
getByokSourceCoverage,
|
||||
getCopilotFeatureAccess,
|
||||
} from './feature-coverage';
|
||||
|
||||
export type CopilotAccessContext = {
|
||||
userId?: string;
|
||||
workspaceId?: string;
|
||||
byokLeaseId?: string;
|
||||
featureKind?: ByokFeatureKind;
|
||||
quotaBackedRoutesAllowed?: boolean;
|
||||
};
|
||||
|
||||
export type CopilotRouteAccess = {
|
||||
byokProfiles: CopilotProviderProfile[];
|
||||
quotaBackedRoutesAvailable: boolean;
|
||||
};
|
||||
|
||||
export type CopilotTurnRouteAccess = {
|
||||
byokProfiles: CopilotProviderProfile[];
|
||||
quotaBackedRoutesAllowed?: boolean;
|
||||
};
|
||||
|
||||
@Injectable()
|
||||
export class CopilotAccessPolicy {
|
||||
constructor(
|
||||
private readonly conversationPolicy: ConversationPolicy,
|
||||
private readonly byok: ByokService
|
||||
) {}
|
||||
|
||||
async getByokProfiles(context: CopilotAccessContext = {}) {
|
||||
const coverage = getByokSourceCoverage(context.featureKind);
|
||||
return await this.byok.getProfiles(context, coverage);
|
||||
}
|
||||
|
||||
async canUseQuotaBackedRoutes(context: CopilotAccessContext = {}) {
|
||||
if (context.quotaBackedRoutesAllowed !== undefined) {
|
||||
return context.quotaBackedRoutesAllowed;
|
||||
}
|
||||
if (!getCopilotFeatureAccess(context.featureKind).quotaMetered) {
|
||||
return true;
|
||||
}
|
||||
if (!context.userId) {
|
||||
return true;
|
||||
}
|
||||
return await this.conversationPolicy.hasQuota(context.userId);
|
||||
}
|
||||
|
||||
async getQuota(userId: string) {
|
||||
return await this.conversationPolicy.getQuota(userId);
|
||||
}
|
||||
|
||||
async checkQuota(userId: string) {
|
||||
await this.conversationPolicy.checkQuota(userId);
|
||||
}
|
||||
|
||||
async resolveRouteAccess(
|
||||
context: CopilotAccessContext = {}
|
||||
): Promise<CopilotRouteAccess> {
|
||||
const [byokProfiles, quotaBackedRoutesAvailable] = await Promise.all([
|
||||
this.getByokProfiles(context),
|
||||
this.canUseQuotaBackedRoutes(context),
|
||||
]);
|
||||
|
||||
return { byokProfiles, quotaBackedRoutesAvailable };
|
||||
}
|
||||
|
||||
async resolveTurnRouteAccess(
|
||||
context: CopilotAccessContext
|
||||
): Promise<CopilotTurnRouteAccess> {
|
||||
const byokProfiles = await this.getByokProfiles(context);
|
||||
if (context.quotaBackedRoutesAllowed === false) {
|
||||
return { byokProfiles, quotaBackedRoutesAllowed: false };
|
||||
}
|
||||
const featureAccess = getCopilotFeatureAccess(context.featureKind);
|
||||
if (!byokProfiles.length && context.userId && featureAccess.quotaMetered) {
|
||||
await this.conversationPolicy.checkQuota(context.userId);
|
||||
}
|
||||
|
||||
const quotaBackedRoutesAllowed = byokProfiles.length
|
||||
? context.quotaBackedRoutesAllowed
|
||||
: true;
|
||||
return { byokProfiles, quotaBackedRoutesAllowed };
|
||||
}
|
||||
|
||||
async assertQuotaOrByok(context: CopilotAccessContext) {
|
||||
const byokProfiles = await this.getByokProfiles(context);
|
||||
if (context.quotaBackedRoutesAllowed === false) {
|
||||
if (!byokProfiles.length) {
|
||||
throw new CopilotQuotaExceeded();
|
||||
}
|
||||
return;
|
||||
}
|
||||
const featureAccess = getCopilotFeatureAccess(context.featureKind);
|
||||
if (!byokProfiles.length && context.userId && featureAccess.quotaMetered) {
|
||||
await this.conversationPolicy.checkQuota(context.userId);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
export { ByokEntitlementPolicy } from './policy';
|
||||
export { WorkspaceByokResolver } from './resolver';
|
||||
export { type ByokProviderRequestContext, ByokService } from './service';
|
||||
export * from './types';
|
||||
105
packages/backend/server/src/plugins/copilot/byok/policy.ts
Normal file
105
packages/backend/server/src/plugins/copilot/byok/policy.ts
Normal file
@@ -0,0 +1,105 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
|
||||
import { ActionForbidden } from '../../../base';
|
||||
import { Models, WorkspaceRole } from '../../../models';
|
||||
|
||||
@Injectable()
|
||||
export class ByokEntitlementPolicy {
|
||||
constructor(private readonly models: Models) {}
|
||||
|
||||
private isUserPlanEntitled(features: string[]) {
|
||||
return (
|
||||
features.includes('pro_plan_v1') ||
|
||||
features.includes('lifetime_pro_plan_v1') ||
|
||||
features.includes('unlimited_copilot')
|
||||
);
|
||||
}
|
||||
|
||||
async hasAiPlan(userId?: string) {
|
||||
if (!userId) return false;
|
||||
const features = await this.models.userFeature.list(userId);
|
||||
return this.isUserPlanEntitled(features);
|
||||
}
|
||||
|
||||
async hasManagementAccess(workspaceId: string, userId?: string) {
|
||||
if (!userId) return false;
|
||||
const role = await this.models.workspaceUser.getActive(workspaceId, userId);
|
||||
return (
|
||||
role?.type === WorkspaceRole.Owner || role?.type === WorkspaceRole.Admin
|
||||
);
|
||||
}
|
||||
|
||||
async assertManagementAccess(workspaceId: string, userId?: string) {
|
||||
if (!(await this.hasManagementAccess(workspaceId, userId))) {
|
||||
throw new ActionForbidden(
|
||||
'BYOK settings require workspace owner or admin.'
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
private async getWorkspaceOwnerId(workspaceId: string) {
|
||||
const workspace = await this.models.workspace.get(workspaceId);
|
||||
if (!workspace) {
|
||||
return null;
|
||||
}
|
||||
|
||||
try {
|
||||
return (await this.models.workspaceUser.getOwner(workspaceId)).id;
|
||||
} catch (error) {
|
||||
if (
|
||||
error instanceof Error &&
|
||||
error.message === 'Workspace owner not found'
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
async hasLocalEntitlement(workspaceId: string, userId?: string) {
|
||||
if (env.selfhosted) return true;
|
||||
|
||||
if (await this.models.workspaceFeature.has(workspaceId, 'team_plan_v1')) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const ownerId = await this.getWorkspaceOwnerId(workspaceId);
|
||||
if (!ownerId) return false;
|
||||
|
||||
if (await this.hasAiPlan(userId)) return true;
|
||||
return await this.hasAiPlan(ownerId);
|
||||
}
|
||||
|
||||
async hasServerEntitlement(workspaceId: string) {
|
||||
if (env.selfhosted) return true;
|
||||
|
||||
if (await this.models.workspaceFeature.has(workspaceId, 'team_plan_v1')) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const ownerId = await this.getWorkspaceOwnerId(workspaceId);
|
||||
if (!ownerId) return false;
|
||||
return await this.hasAiPlan(ownerId);
|
||||
}
|
||||
|
||||
async hasEntitlement(workspaceId: string, userId?: string) {
|
||||
const [serverEntitled, localEntitled] = await Promise.all([
|
||||
this.hasServerEntitlement(workspaceId),
|
||||
this.hasLocalEntitlement(workspaceId, userId),
|
||||
]);
|
||||
|
||||
return [serverEntitled, localEntitled] as const;
|
||||
}
|
||||
|
||||
async assertServerEntitled(workspaceId: string) {
|
||||
if (!(await this.hasServerEntitlement(workspaceId))) {
|
||||
throw new ActionForbidden('BYOK requires Pro, Team, or Believer.');
|
||||
}
|
||||
}
|
||||
|
||||
async assertLocalEntitled(workspaceId: string, userId?: string) {
|
||||
if (!(await this.hasLocalEntitlement(workspaceId, userId))) {
|
||||
throw new ActionForbidden('BYOK requires Pro, Team, or Believer.');
|
||||
}
|
||||
}
|
||||
}
|
||||
405
packages/backend/server/src/plugins/copilot/byok/resolver.ts
Normal file
405
packages/backend/server/src/plugins/copilot/byok/resolver.ts
Normal file
@@ -0,0 +1,405 @@
|
||||
import {
|
||||
Args,
|
||||
Field,
|
||||
ID,
|
||||
InputType,
|
||||
Mutation,
|
||||
ObjectType,
|
||||
Parent,
|
||||
ResolveField,
|
||||
Resolver,
|
||||
} from '@nestjs/graphql';
|
||||
import { SafeIntResolver } from 'graphql-scalars';
|
||||
|
||||
import { Throttle } from '../../../base';
|
||||
import { CurrentUser } from '../../../core/auth';
|
||||
import { AccessController } from '../../../core/permission';
|
||||
import { WorkspaceType } from '../../../core/workspaces';
|
||||
import { ByokEntitlementPolicy } from './policy';
|
||||
import { ByokKeyConfig, ByokLocalLeaseProvider, ByokService } from './service';
|
||||
import { ByokKeyStorage, ByokKeyTestStatus, ByokProvider } from './types';
|
||||
|
||||
@ObjectType()
|
||||
export class WorkspaceByokKeyConfigType implements ByokKeyConfig {
|
||||
@Field(() => ID)
|
||||
id!: string;
|
||||
|
||||
@Field(() => ByokProvider)
|
||||
provider!: ByokProvider;
|
||||
|
||||
@Field(() => String)
|
||||
name!: string;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
description!: string | null;
|
||||
|
||||
@Field(() => ByokKeyStorage)
|
||||
storage!: ByokKeyStorage;
|
||||
|
||||
@Field(() => Boolean)
|
||||
configured!: boolean;
|
||||
|
||||
@Field(() => Boolean)
|
||||
enabled!: boolean;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
endpoint!: string | null;
|
||||
|
||||
@Field(() => Boolean)
|
||||
endpointEditable!: boolean;
|
||||
|
||||
@Field(() => SafeIntResolver)
|
||||
sortOrder!: number;
|
||||
|
||||
@Field(() => [String])
|
||||
capabilities!: string[];
|
||||
|
||||
@Field(() => ByokKeyTestStatus)
|
||||
testStatus!: ByokKeyTestStatus;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
disabledReason!: string | null;
|
||||
|
||||
@Field(() => Date, { nullable: true })
|
||||
lastTestedAt!: Date | null;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
lastTestError!: string | null;
|
||||
|
||||
@Field(() => Date, { nullable: true })
|
||||
lastUsedAt!: Date | null;
|
||||
|
||||
@Field(() => Date, { nullable: true })
|
||||
lastErrorAt!: Date | null;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
lastError!: string | null;
|
||||
}
|
||||
|
||||
@ObjectType()
|
||||
class WorkspaceByokCapabilityWarningType {
|
||||
@Field(() => String)
|
||||
featureKind!: string;
|
||||
|
||||
@Field(() => String)
|
||||
reason!: string;
|
||||
|
||||
@Field(() => [ByokProvider])
|
||||
requiredProviders!: ByokProvider[];
|
||||
}
|
||||
|
||||
@ObjectType()
|
||||
class WorkspaceByokSettingsType {
|
||||
@Field(() => String)
|
||||
workspaceId!: string;
|
||||
|
||||
@Field(() => Boolean)
|
||||
entitled!: boolean;
|
||||
|
||||
@Field(() => Boolean)
|
||||
serverEntitled!: boolean;
|
||||
|
||||
@Field(() => Boolean)
|
||||
localEntitled!: boolean;
|
||||
|
||||
@Field(() => [String])
|
||||
entitlementRequired!: string[];
|
||||
|
||||
@Field(() => [WorkspaceByokKeyConfigType])
|
||||
keys!: WorkspaceByokKeyConfigType[];
|
||||
|
||||
@Field(() => [ByokProvider])
|
||||
allowedProviders!: ByokProvider[];
|
||||
|
||||
@Field(() => Boolean)
|
||||
localStorageSupported!: boolean;
|
||||
|
||||
@Field(() => Boolean)
|
||||
customEndpointSupported!: boolean;
|
||||
|
||||
@Field(() => Boolean)
|
||||
hasAiPlan!: boolean;
|
||||
|
||||
@Field(() => [WorkspaceByokCapabilityWarningType])
|
||||
warnings!: WorkspaceByokCapabilityWarningType[];
|
||||
}
|
||||
|
||||
@ObjectType()
|
||||
class WorkspaceByokUsagePointType {
|
||||
@Field(() => Date)
|
||||
date!: Date;
|
||||
|
||||
@Field(() => String)
|
||||
featureKind!: string;
|
||||
|
||||
@Field(() => SafeIntResolver)
|
||||
totalTokens!: number;
|
||||
}
|
||||
|
||||
@ObjectType()
|
||||
class TestWorkspaceByokConfigResultType {
|
||||
@Field(() => Boolean)
|
||||
ok!: boolean;
|
||||
|
||||
@Field(() => ByokKeyTestStatus)
|
||||
status!: ByokKeyTestStatus;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
message!: string | null;
|
||||
}
|
||||
|
||||
@ObjectType()
|
||||
class CreateWorkspaceByokLocalLeaseResultType {
|
||||
@Field(() => String)
|
||||
leaseId!: string;
|
||||
|
||||
@Field(() => Date)
|
||||
expiresAt!: Date;
|
||||
}
|
||||
|
||||
@InputType()
|
||||
class UpsertWorkspaceByokConfigInput {
|
||||
@Field(() => ID, { nullable: true })
|
||||
id?: string;
|
||||
|
||||
@Field(() => String)
|
||||
workspaceId!: string;
|
||||
|
||||
@Field(() => ByokProvider)
|
||||
provider!: ByokProvider;
|
||||
|
||||
@Field(() => String)
|
||||
name!: string;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
description?: string | null;
|
||||
|
||||
@Field(() => ByokKeyStorage)
|
||||
storage!: ByokKeyStorage;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
apiKey?: string | null;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
endpoint?: string | null;
|
||||
|
||||
@Field(() => SafeIntResolver, { nullable: true })
|
||||
sortOrder?: number | null;
|
||||
|
||||
@Field(() => Boolean, { nullable: true })
|
||||
enabled?: boolean | null;
|
||||
}
|
||||
|
||||
@InputType()
|
||||
class TestWorkspaceByokConfigInput {
|
||||
@Field(() => String)
|
||||
workspaceId!: string;
|
||||
|
||||
@Field(() => ByokProvider)
|
||||
provider!: ByokProvider;
|
||||
|
||||
@Field(() => ByokKeyStorage)
|
||||
storage!: ByokKeyStorage;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
apiKey?: string | null;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
endpoint?: string | null;
|
||||
|
||||
@Field(() => ID, { nullable: true })
|
||||
configId?: string | null;
|
||||
}
|
||||
|
||||
@InputType()
|
||||
class ReorderWorkspaceByokConfigsInput {
|
||||
@Field(() => String)
|
||||
workspaceId!: string;
|
||||
|
||||
@Field(() => ByokKeyStorage)
|
||||
storage!: ByokKeyStorage;
|
||||
|
||||
@Field(() => [ID])
|
||||
ids!: string[];
|
||||
}
|
||||
|
||||
@InputType()
|
||||
class CreateWorkspaceByokLocalLeaseProviderInput implements ByokLocalLeaseProvider {
|
||||
@Field(() => ByokProvider)
|
||||
provider!: ByokProvider;
|
||||
|
||||
@Field(() => String)
|
||||
name!: string;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
description?: string | null;
|
||||
|
||||
@Field(() => String)
|
||||
apiKey!: string;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
endpoint?: string | null;
|
||||
|
||||
@Field(() => SafeIntResolver, { nullable: true })
|
||||
sortOrder?: number | null;
|
||||
|
||||
@Field(() => Boolean, { nullable: true })
|
||||
enabled?: boolean | null;
|
||||
}
|
||||
|
||||
@InputType()
|
||||
class CreateWorkspaceByokLocalLeaseInput {
|
||||
@Field(() => String)
|
||||
workspaceId!: string;
|
||||
|
||||
@Field(() => [CreateWorkspaceByokLocalLeaseProviderInput])
|
||||
providers!: CreateWorkspaceByokLocalLeaseProviderInput[];
|
||||
}
|
||||
|
||||
@Resolver(() => WorkspaceType)
|
||||
export class WorkspaceByokResolver {
|
||||
constructor(
|
||||
private readonly ac: AccessController,
|
||||
private readonly entitlement: ByokEntitlementPolicy,
|
||||
private readonly byok: ByokService
|
||||
) {}
|
||||
|
||||
@ResolveField(() => WorkspaceByokSettingsType, {
|
||||
name: 'byokSettings',
|
||||
complexity: 2,
|
||||
})
|
||||
async settings(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Parent() workspace: WorkspaceType
|
||||
) {
|
||||
await this.ac
|
||||
.user(user.id)
|
||||
.workspace(workspace.id)
|
||||
.allowLocal()
|
||||
.assert('Workspace.Settings.Read');
|
||||
await this.entitlement.assertManagementAccess(workspace.id, user.id);
|
||||
return await this.byok.getSettings(workspace.id, user.id);
|
||||
}
|
||||
|
||||
@ResolveField(() => [WorkspaceByokUsagePointType], {
|
||||
name: 'byokUsage',
|
||||
complexity: 2,
|
||||
})
|
||||
async usage(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Parent() workspace: WorkspaceType,
|
||||
@Args('from', { type: () => Date }) from: Date,
|
||||
@Args('to', { type: () => Date }) to: Date
|
||||
) {
|
||||
await this.ac
|
||||
.user(user.id)
|
||||
.workspace(workspace.id)
|
||||
.allowLocal()
|
||||
.assert('Workspace.Settings.Read');
|
||||
await this.entitlement.assertManagementAccess(workspace.id, user.id);
|
||||
return await this.byok.getUsage(workspace.id, from, to);
|
||||
}
|
||||
|
||||
@Throttle('strict')
|
||||
@Mutation(() => TestWorkspaceByokConfigResultType)
|
||||
async testWorkspaceByokConfig(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args('input') input: TestWorkspaceByokConfigInput
|
||||
) {
|
||||
await this.ac
|
||||
.user(user.id)
|
||||
.workspace(input.workspaceId)
|
||||
.allowLocal()
|
||||
.assert('Workspace.Settings.Update');
|
||||
await this.entitlement.assertManagementAccess(input.workspaceId, user.id);
|
||||
if (input.storage === ByokKeyStorage.server) {
|
||||
await this.entitlement.assertServerEntitled(input.workspaceId);
|
||||
} else {
|
||||
await this.entitlement.assertLocalEntitled(input.workspaceId, user.id);
|
||||
}
|
||||
return await this.byok.testConfig({ ...input, userId: user.id });
|
||||
}
|
||||
|
||||
@Mutation(() => WorkspaceByokKeyConfigType)
|
||||
@Throttle('strict')
|
||||
async upsertWorkspaceByokConfig(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args('input') input: UpsertWorkspaceByokConfigInput
|
||||
) {
|
||||
await this.ac
|
||||
.user(user.id)
|
||||
.workspace(input.workspaceId)
|
||||
.allowLocal()
|
||||
.assert('Workspace.Settings.Update');
|
||||
await this.entitlement.assertManagementAccess(input.workspaceId, user.id);
|
||||
await this.entitlement.assertServerEntitled(input.workspaceId);
|
||||
return await this.byok.upsertConfig({ ...input, userId: user.id });
|
||||
}
|
||||
|
||||
@Mutation(() => [WorkspaceByokKeyConfigType])
|
||||
@Throttle('strict')
|
||||
async reorderWorkspaceByokConfigs(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args('input') input: ReorderWorkspaceByokConfigsInput
|
||||
) {
|
||||
await this.ac
|
||||
.user(user.id)
|
||||
.workspace(input.workspaceId)
|
||||
.allowLocal()
|
||||
.assert('Workspace.Settings.Update');
|
||||
await this.entitlement.assertManagementAccess(input.workspaceId, user.id);
|
||||
await this.entitlement.assertServerEntitled(input.workspaceId);
|
||||
return await this.byok.reorderConfigs({ ...input, userId: user.id });
|
||||
}
|
||||
|
||||
@Mutation(() => Boolean)
|
||||
@Throttle('strict')
|
||||
async deleteWorkspaceByokConfig(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args('id', { type: () => ID }) id: string,
|
||||
@Args('workspaceId', { type: () => String }) workspaceId: string
|
||||
) {
|
||||
await this.ac
|
||||
.user(user.id)
|
||||
.workspace(workspaceId)
|
||||
.allowLocal()
|
||||
.assert('Workspace.Settings.Update');
|
||||
await this.entitlement.assertManagementAccess(workspaceId, user.id);
|
||||
await this.entitlement.assertServerEntitled(workspaceId);
|
||||
return await this.byok.deleteConfig(workspaceId, id, user.id);
|
||||
}
|
||||
|
||||
@Mutation(() => Boolean)
|
||||
@Throttle('strict')
|
||||
async clearWorkspaceByokConfigs(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args('workspaceId', { type: () => String }) workspaceId: string,
|
||||
@Args('provider', { type: () => ByokProvider, nullable: true })
|
||||
provider?: ByokProvider | null
|
||||
) {
|
||||
await this.ac
|
||||
.user(user.id)
|
||||
.workspace(workspaceId)
|
||||
.allowLocal()
|
||||
.assert('Workspace.Settings.Update');
|
||||
await this.entitlement.assertManagementAccess(workspaceId, user.id);
|
||||
await this.entitlement.assertServerEntitled(workspaceId);
|
||||
return await this.byok.clearConfigs(workspaceId, provider, user.id);
|
||||
}
|
||||
|
||||
@Mutation(() => CreateWorkspaceByokLocalLeaseResultType)
|
||||
@Throttle('strict')
|
||||
async createWorkspaceByokLocalLease(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args('input') input: CreateWorkspaceByokLocalLeaseInput
|
||||
) {
|
||||
await this.ac
|
||||
.user(user.id)
|
||||
.workspace(input.workspaceId)
|
||||
.allowLocal()
|
||||
.assert('Workspace.Copilot');
|
||||
await this.entitlement.assertManagementAccess(input.workspaceId, user.id);
|
||||
await this.entitlement.assertLocalEntitled(input.workspaceId, user.id);
|
||||
return await this.byok.createLocalLease({ ...input, userId: user.id });
|
||||
}
|
||||
}
|
||||
883
packages/backend/server/src/plugins/copilot/byok/service.ts
Normal file
883
packages/backend/server/src/plugins/copilot/byok/service.ts
Normal file
@@ -0,0 +1,883 @@
|
||||
import { createHash, createHmac, randomUUID } from 'node:crypto';
|
||||
|
||||
import { BadRequestException, Injectable } from '@nestjs/common';
|
||||
|
||||
import { BadRequest, Cache, CryptoHelper, metrics } from '../../../base';
|
||||
import { Models } from '../../../models';
|
||||
import type { CopilotProviderProfile } from '../config';
|
||||
import { ByokEntitlementPolicy } from './policy';
|
||||
import {
|
||||
BYOK_ALLOWED_PROVIDERS,
|
||||
type ByokFeatureKind,
|
||||
ByokKeyStorage,
|
||||
ByokKeyTestStatus,
|
||||
ByokProvider,
|
||||
ByokProviderSource,
|
||||
byokProviderToCopilotType,
|
||||
isByokProvider,
|
||||
} from './types';
|
||||
|
||||
const LOCAL_LEASE_TTL_MS = 10 * 60 * 1000;
|
||||
const BYOK_PROFILE_PRIORITY_BASE = 10_000;
|
||||
const SERVER_PROFILE_PRIORITY_OFFSET = 2_000;
|
||||
const TEST_TIMEOUT_MS = 10_000;
|
||||
|
||||
export type ByokProviderRequestContext = {
|
||||
userId?: string;
|
||||
workspaceId?: string;
|
||||
byokLeaseId?: string;
|
||||
};
|
||||
|
||||
export type ByokProfileSourceFilter = {
|
||||
local?: boolean;
|
||||
server?: boolean;
|
||||
};
|
||||
|
||||
export type ByokKeyConfig = {
|
||||
id: string;
|
||||
provider: ByokProvider;
|
||||
name: string;
|
||||
description: string | null;
|
||||
storage: ByokKeyStorage;
|
||||
configured: boolean;
|
||||
enabled: boolean;
|
||||
endpoint: string | null;
|
||||
endpointEditable: boolean;
|
||||
sortOrder: number;
|
||||
capabilities: string[];
|
||||
testStatus: ByokKeyTestStatus;
|
||||
disabledReason: string | null;
|
||||
lastTestedAt: Date | null;
|
||||
lastTestError: string | null;
|
||||
lastUsedAt: Date | null;
|
||||
lastErrorAt: Date | null;
|
||||
lastError: string | null;
|
||||
};
|
||||
|
||||
export type ByokSettings = {
|
||||
workspaceId: string;
|
||||
entitled: boolean;
|
||||
serverEntitled: boolean;
|
||||
localEntitled: boolean;
|
||||
entitlementRequired: string[];
|
||||
keys: ByokKeyConfig[];
|
||||
allowedProviders: ByokProvider[];
|
||||
localStorageSupported: boolean;
|
||||
customEndpointSupported: boolean;
|
||||
hasAiPlan: boolean;
|
||||
warnings: Array<{
|
||||
featureKind: string;
|
||||
reason: string;
|
||||
requiredProviders: ByokProvider[];
|
||||
}>;
|
||||
};
|
||||
|
||||
export type ByokLocalLeaseProvider = {
|
||||
provider: ByokProvider;
|
||||
name: string;
|
||||
description?: string | null;
|
||||
apiKey: string;
|
||||
endpoint?: string | null;
|
||||
sortOrder?: number | null;
|
||||
enabled?: boolean | null;
|
||||
};
|
||||
|
||||
type LocalLeasePayload = {
|
||||
workspaceId: string;
|
||||
userId: string;
|
||||
providers: Array<
|
||||
Omit<ByokLocalLeaseProvider, 'apiKey'> & { encryptedApiKey: string }
|
||||
>;
|
||||
};
|
||||
|
||||
type LocalLeaseActive = {
|
||||
leaseId: string;
|
||||
expiresAt: string;
|
||||
};
|
||||
|
||||
type ByokProfileMeta = {
|
||||
source: ByokProviderSource.Server | ByokProviderSource.Local;
|
||||
keyId?: string;
|
||||
provider: ByokProvider;
|
||||
};
|
||||
|
||||
@Injectable()
|
||||
export class ByokService {
|
||||
constructor(
|
||||
private readonly models: Models,
|
||||
private readonly crypto: CryptoHelper,
|
||||
private readonly cache: Cache,
|
||||
private readonly entitlement: ByokEntitlementPolicy
|
||||
) {}
|
||||
|
||||
get customEndpointSupported() {
|
||||
return env.selfhosted;
|
||||
}
|
||||
|
||||
async getSettings(
|
||||
workspaceId: string,
|
||||
userId?: string
|
||||
): Promise<ByokSettings> {
|
||||
if (!(await this.entitlement.hasManagementAccess(workspaceId, userId))) {
|
||||
return {
|
||||
workspaceId,
|
||||
entitled: false,
|
||||
serverEntitled: false,
|
||||
localEntitled: false,
|
||||
entitlementRequired: ['Workspace owner or admin'],
|
||||
keys: [],
|
||||
allowedProviders: [...BYOK_ALLOWED_PROVIDERS],
|
||||
localStorageSupported: false,
|
||||
customEndpointSupported: this.customEndpointSupported,
|
||||
hasAiPlan: await this.entitlement.hasAiPlan(userId),
|
||||
warnings: [],
|
||||
};
|
||||
}
|
||||
|
||||
const [serverEntitled, localEntitled] =
|
||||
await this.entitlement.hasEntitlement(workspaceId, userId);
|
||||
const entitled = serverEntitled || localEntitled;
|
||||
if (!entitled) {
|
||||
return {
|
||||
workspaceId,
|
||||
entitled: false,
|
||||
serverEntitled: false,
|
||||
localEntitled: false,
|
||||
entitlementRequired: ['Pro', 'Team', 'Believer'],
|
||||
keys: [],
|
||||
allowedProviders: [...BYOK_ALLOWED_PROVIDERS],
|
||||
localStorageSupported: false,
|
||||
customEndpointSupported: this.customEndpointSupported,
|
||||
hasAiPlan: await this.entitlement.hasAiPlan(userId),
|
||||
warnings: [],
|
||||
};
|
||||
}
|
||||
|
||||
const rows = serverEntitled
|
||||
? await this.models.copilotWorkspaceByokConfig.list(workspaceId)
|
||||
: [];
|
||||
const keys = rows.map(row => this.toKeyConfig(row));
|
||||
|
||||
return {
|
||||
workspaceId,
|
||||
entitled: true,
|
||||
serverEntitled,
|
||||
localEntitled,
|
||||
entitlementRequired: ['Pro', 'Team', 'Believer'],
|
||||
keys,
|
||||
allowedProviders: [...BYOK_ALLOWED_PROVIDERS],
|
||||
localStorageSupported: false,
|
||||
customEndpointSupported: this.customEndpointSupported,
|
||||
hasAiPlan: await this.entitlement.hasAiPlan(userId),
|
||||
warnings: this.buildWarnings(keys),
|
||||
};
|
||||
}
|
||||
|
||||
async upsertConfig(input: {
|
||||
id?: string | null;
|
||||
workspaceId: string;
|
||||
provider: ByokProvider;
|
||||
name: string;
|
||||
description?: string | null;
|
||||
storage: ByokKeyStorage;
|
||||
apiKey?: string | null;
|
||||
endpoint?: string | null;
|
||||
sortOrder?: number | null;
|
||||
enabled?: boolean | null;
|
||||
userId?: string;
|
||||
}): Promise<ByokKeyConfig> {
|
||||
await this.entitlement.assertManagementAccess(
|
||||
input.workspaceId,
|
||||
input.userId
|
||||
);
|
||||
await this.entitlement.assertServerEntitled(input.workspaceId);
|
||||
this.assertProvider(input.provider);
|
||||
if (input.storage !== ByokKeyStorage.server) {
|
||||
throw new BadRequestException('Only server BYOK keys are persisted.');
|
||||
}
|
||||
const existing = input.id
|
||||
? await this.models.copilotWorkspaceByokConfig.get(input.id)
|
||||
: null;
|
||||
if (input.id && (!existing || existing.workspaceId !== input.workspaceId)) {
|
||||
throw new BadRequest('BYOK config not found.');
|
||||
}
|
||||
const encryptedApiKey = input.apiKey
|
||||
? this.crypto.encrypt(input.apiKey)
|
||||
: undefined;
|
||||
|
||||
if (!input.id && !encryptedApiKey) {
|
||||
throw new BadRequestException('apiKey is required.');
|
||||
}
|
||||
|
||||
const description =
|
||||
input.description !== undefined
|
||||
? input.description?.trim() || null
|
||||
: (existing?.description ?? null);
|
||||
const endpoint =
|
||||
input.endpoint !== undefined
|
||||
? this.normalizeEndpoint(input.endpoint)
|
||||
: (existing?.endpoint ?? null);
|
||||
const sortOrder = input.sortOrder ?? existing?.sortOrder ?? 0;
|
||||
const enabled = input.enabled ?? existing?.enabled ?? true;
|
||||
|
||||
const row = await this.models.copilotWorkspaceByokConfig.upsert({
|
||||
id: input.id,
|
||||
workspaceId: input.workspaceId,
|
||||
provider: input.provider,
|
||||
name: input.name.trim(),
|
||||
description,
|
||||
encryptedApiKey,
|
||||
endpoint,
|
||||
sortOrder,
|
||||
enabled,
|
||||
userId: input.userId,
|
||||
});
|
||||
|
||||
return this.toKeyConfig(row);
|
||||
}
|
||||
|
||||
async reorderConfigs(input: {
|
||||
workspaceId: string;
|
||||
storage: ByokKeyStorage;
|
||||
ids: string[];
|
||||
userId?: string;
|
||||
}) {
|
||||
await this.entitlement.assertManagementAccess(
|
||||
input.workspaceId,
|
||||
input.userId
|
||||
);
|
||||
await this.entitlement.assertServerEntitled(input.workspaceId);
|
||||
if (input.storage !== ByokKeyStorage.server) {
|
||||
throw new BadRequestException('Only server BYOK keys are persisted.');
|
||||
}
|
||||
await this.models.copilotWorkspaceByokConfig.reorder(
|
||||
input.workspaceId,
|
||||
input.ids,
|
||||
input.userId
|
||||
);
|
||||
return (await this.getSettings(input.workspaceId, input.userId)).keys;
|
||||
}
|
||||
|
||||
async deleteConfig(workspaceId: string, id: string, _userId?: string) {
|
||||
await this.entitlement.assertManagementAccess(workspaceId, _userId);
|
||||
await this.entitlement.assertServerEntitled(workspaceId);
|
||||
await this.models.copilotWorkspaceByokConfig.delete(workspaceId, id);
|
||||
return true;
|
||||
}
|
||||
|
||||
async clearConfigs(
|
||||
workspaceId: string,
|
||||
provider: ByokProvider | null | undefined,
|
||||
_userId?: string
|
||||
) {
|
||||
await this.entitlement.assertManagementAccess(workspaceId, _userId);
|
||||
await this.entitlement.assertServerEntitled(workspaceId);
|
||||
await this.models.copilotWorkspaceByokConfig.clear(workspaceId, provider);
|
||||
return true;
|
||||
}
|
||||
|
||||
async testConfig(input: {
|
||||
workspaceId: string;
|
||||
provider: ByokProvider;
|
||||
storage: ByokKeyStorage;
|
||||
apiKey?: string | null;
|
||||
endpoint?: string | null;
|
||||
configId?: string | null;
|
||||
userId?: string;
|
||||
}) {
|
||||
await this.entitlement.assertManagementAccess(
|
||||
input.workspaceId,
|
||||
input.userId
|
||||
);
|
||||
if (input.storage === ByokKeyStorage.server) {
|
||||
await this.entitlement.assertServerEntitled(input.workspaceId);
|
||||
} else {
|
||||
await this.entitlement.assertLocalEntitled(
|
||||
input.workspaceId,
|
||||
input.userId
|
||||
);
|
||||
}
|
||||
this.assertProvider(input.provider);
|
||||
let apiKey = input.apiKey;
|
||||
let endpoint = this.normalizeEndpoint(input.endpoint);
|
||||
if (!apiKey && input.configId && input.storage === ByokKeyStorage.server) {
|
||||
const config = await this.models.copilotWorkspaceByokConfig.get(
|
||||
input.configId
|
||||
);
|
||||
if (
|
||||
!config ||
|
||||
config.workspaceId !== input.workspaceId ||
|
||||
config.provider !== input.provider
|
||||
) {
|
||||
throw new BadRequestException('BYOK config not found.');
|
||||
}
|
||||
apiKey = this.crypto.decrypt(config.encryptedApiKey);
|
||||
endpoint =
|
||||
input.endpoint !== undefined
|
||||
? endpoint
|
||||
: this.normalizeEndpoint(config.endpoint);
|
||||
}
|
||||
if (!apiKey) {
|
||||
throw new BadRequestException('apiKey is required.');
|
||||
}
|
||||
|
||||
try {
|
||||
await this.runProviderProbe(input.provider, apiKey, endpoint);
|
||||
if (input.configId && input.storage === ByokKeyStorage.server) {
|
||||
await this.models.copilotWorkspaceByokConfig.markValidated(
|
||||
input.workspaceId,
|
||||
input.configId,
|
||||
input.userId
|
||||
);
|
||||
}
|
||||
metrics.ai.counter('byok_test_key').add(1, {
|
||||
workspace: input.workspaceId,
|
||||
provider: input.provider,
|
||||
storage: input.storage,
|
||||
result: 'passed',
|
||||
});
|
||||
return { ok: true, status: ByokKeyTestStatus.passed, message: null };
|
||||
} catch (error) {
|
||||
const message = this.sanitizeError(error);
|
||||
if (input.configId && input.storage === ByokKeyStorage.server) {
|
||||
await this.models.copilotWorkspaceByokConfig.markFailure(
|
||||
input.workspaceId,
|
||||
input.configId,
|
||||
message
|
||||
);
|
||||
}
|
||||
metrics.ai.counter('byok_test_key').add(1, {
|
||||
workspace: input.workspaceId,
|
||||
provider: input.provider,
|
||||
storage: input.storage,
|
||||
result: 'failed',
|
||||
});
|
||||
return { ok: false, status: ByokKeyTestStatus.failed, message };
|
||||
}
|
||||
}
|
||||
|
||||
async createLocalLease(input: {
|
||||
workspaceId: string;
|
||||
providers: ByokLocalLeaseProvider[];
|
||||
userId: string;
|
||||
}) {
|
||||
await this.entitlement.assertManagementAccess(
|
||||
input.workspaceId,
|
||||
input.userId
|
||||
);
|
||||
await this.entitlement.assertLocalEntitled(input.workspaceId, input.userId);
|
||||
const providers = input.providers.map(provider => {
|
||||
this.assertProvider(provider.provider);
|
||||
const endpoint = this.normalizeEndpoint(provider.endpoint);
|
||||
return { ...provider, endpoint };
|
||||
});
|
||||
const activeCacheKey = this.localLeaseActiveCacheKey({
|
||||
...input,
|
||||
providers,
|
||||
});
|
||||
const activeLease = await this.getActiveLocalLease(activeCacheKey);
|
||||
if (activeLease) return activeLease;
|
||||
|
||||
const leaseId = randomUUID();
|
||||
const expiresAt = new Date(Date.now() + LOCAL_LEASE_TTL_MS);
|
||||
const payload: LocalLeasePayload = {
|
||||
workspaceId: input.workspaceId,
|
||||
userId: input.userId,
|
||||
providers: providers.map(provider => ({
|
||||
provider: provider.provider,
|
||||
name: provider.name,
|
||||
description: provider.description,
|
||||
encryptedApiKey: this.crypto.encrypt(provider.apiKey),
|
||||
endpoint: provider.endpoint,
|
||||
sortOrder: provider.sortOrder,
|
||||
enabled: provider.enabled,
|
||||
})),
|
||||
};
|
||||
await this.cache.set(this.leaseCacheKey(leaseId), payload, {
|
||||
ttl: LOCAL_LEASE_TTL_MS,
|
||||
});
|
||||
const registered = await this.cache.setnx<LocalLeaseActive>(
|
||||
activeCacheKey,
|
||||
{ leaseId, expiresAt: expiresAt.toISOString() },
|
||||
{ ttl: LOCAL_LEASE_TTL_MS }
|
||||
);
|
||||
if (!registered) {
|
||||
const current = await this.getActiveLocalLease(activeCacheKey);
|
||||
if (current) {
|
||||
await this.cache.delete(this.leaseCacheKey(leaseId));
|
||||
return current;
|
||||
}
|
||||
}
|
||||
return { leaseId, expiresAt };
|
||||
}
|
||||
|
||||
async getProfiles(
|
||||
context: ByokProviderRequestContext = {},
|
||||
sources: ByokProfileSourceFilter = { local: true, server: true }
|
||||
): Promise<CopilotProviderProfile[]> {
|
||||
if (!context.workspaceId) {
|
||||
return [];
|
||||
}
|
||||
const [localEntitled, serverEntitled] = await Promise.all([
|
||||
this.entitlement.hasLocalEntitlement(context.workspaceId, context.userId),
|
||||
this.entitlement.hasServerEntitlement(context.workspaceId),
|
||||
]);
|
||||
const [localProfiles, serverProfiles] = await Promise.all([
|
||||
sources.local && localEntitled
|
||||
? this.getLocalProfiles(context)
|
||||
: Promise.resolve([]),
|
||||
sources.server && serverEntitled
|
||||
? this.getServerProfiles(context.workspaceId)
|
||||
: Promise.resolve([]),
|
||||
]);
|
||||
|
||||
return [...localProfiles, ...serverProfiles];
|
||||
}
|
||||
|
||||
async recordUsage(input: {
|
||||
workspaceId?: string;
|
||||
userId?: string;
|
||||
providerId?: string;
|
||||
model?: string | null;
|
||||
featureKind: ByokFeatureKind;
|
||||
sessionId?: string;
|
||||
taskId?: string;
|
||||
actionId?: string;
|
||||
billingUnitId?: string;
|
||||
usage?: {
|
||||
prompt_tokens?: number;
|
||||
completion_tokens?: number;
|
||||
total_tokens?: number;
|
||||
cached_tokens?: number;
|
||||
};
|
||||
}) {
|
||||
if (!input.workspaceId || !input.providerId) return;
|
||||
const meta = this.parseProfileMeta(input.providerId, input.workspaceId);
|
||||
if (!meta) return;
|
||||
|
||||
metrics.ai.counter('byok_usage').add(1, {
|
||||
workspace: input.workspaceId,
|
||||
provider: meta.provider,
|
||||
source: meta.source,
|
||||
feature: input.featureKind,
|
||||
});
|
||||
await this.models.copilotUsage.create({
|
||||
workspaceId: input.workspaceId,
|
||||
userId: input.userId,
|
||||
provider: meta.provider,
|
||||
providerSource: meta.source,
|
||||
featureKind: input.featureKind,
|
||||
model: input.model ?? null,
|
||||
sessionId: input.sessionId,
|
||||
taskId: input.taskId,
|
||||
actionId: input.actionId,
|
||||
billingUnitId: input.billingUnitId,
|
||||
promptTokens: input.usage?.prompt_tokens ?? 0,
|
||||
completionTokens: input.usage?.completion_tokens ?? 0,
|
||||
totalTokens: input.usage?.total_tokens ?? 0,
|
||||
cachedTokens: input.usage?.cached_tokens ?? 0,
|
||||
});
|
||||
if (meta.source === ByokProviderSource.Server && meta.keyId) {
|
||||
await this.models.copilotWorkspaceByokConfig.touchUsed(
|
||||
input.workspaceId,
|
||||
meta.keyId
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async recordProviderFailure(input: {
|
||||
workspaceId?: string;
|
||||
providerId?: string;
|
||||
featureKind: ByokFeatureKind;
|
||||
error: unknown;
|
||||
}) {
|
||||
if (!input.workspaceId || !input.providerId) return;
|
||||
const meta = this.parseProfileMeta(input.providerId, input.workspaceId);
|
||||
if (!meta) return;
|
||||
|
||||
const message = this.sanitizeError(input.error);
|
||||
metrics.ai.counter('byok_route_failure').add(1, {
|
||||
workspace: input.workspaceId,
|
||||
provider: meta.provider,
|
||||
source: meta.source,
|
||||
feature: input.featureKind,
|
||||
});
|
||||
if (meta.source === ByokProviderSource.Server && meta.keyId) {
|
||||
await this.models.copilotWorkspaceByokConfig.markFailure(
|
||||
input.workspaceId,
|
||||
meta.keyId,
|
||||
message
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async getUsage(workspaceId: string, from: Date, to: Date) {
|
||||
return await this.models.copilotUsage.aggregateByDay({
|
||||
workspaceId,
|
||||
from,
|
||||
to,
|
||||
providerSources: [ByokProviderSource.Server, ByokProviderSource.Local],
|
||||
});
|
||||
}
|
||||
|
||||
private async getServerProfiles(workspaceId: string) {
|
||||
const rows =
|
||||
await this.models.copilotWorkspaceByokConfig.listEnabled(workspaceId);
|
||||
|
||||
return rows
|
||||
.filter(row => isByokProvider(row.provider))
|
||||
.map((row, index): CopilotProviderProfile => {
|
||||
const provider = row.provider as ByokProvider;
|
||||
return {
|
||||
id: this.profileId(workspaceId, provider, row.id, 'server'),
|
||||
type: byokProviderToCopilotType(provider),
|
||||
priority:
|
||||
BYOK_PROFILE_PRIORITY_BASE - SERVER_PROFILE_PRIORITY_OFFSET - index,
|
||||
config: this.providerConfig(
|
||||
provider,
|
||||
row.encryptedApiKey,
|
||||
row.endpoint
|
||||
),
|
||||
} as CopilotProviderProfile;
|
||||
});
|
||||
}
|
||||
|
||||
private async getLocalProfiles(context: ByokProviderRequestContext) {
|
||||
if (!context.byokLeaseId || !context.workspaceId || !context.userId) {
|
||||
return [];
|
||||
}
|
||||
if (
|
||||
!(await this.entitlement.hasManagementAccess(
|
||||
context.workspaceId,
|
||||
context.userId
|
||||
))
|
||||
) {
|
||||
return [];
|
||||
}
|
||||
const lease = await this.cache.get<LocalLeasePayload>(
|
||||
this.leaseCacheKey(context.byokLeaseId)
|
||||
);
|
||||
if (
|
||||
!lease ||
|
||||
lease.workspaceId !== context.workspaceId ||
|
||||
lease.userId !== context.userId
|
||||
) {
|
||||
return [];
|
||||
}
|
||||
return lease.providers
|
||||
.filter(provider => provider.enabled !== false)
|
||||
.map((provider, index): CopilotProviderProfile => {
|
||||
return {
|
||||
id: this.profileId(
|
||||
context.workspaceId ?? lease.workspaceId,
|
||||
provider.provider,
|
||||
`${index}`,
|
||||
'local'
|
||||
),
|
||||
type: byokProviderToCopilotType(provider.provider),
|
||||
priority: BYOK_PROFILE_PRIORITY_BASE - index,
|
||||
config: this.providerConfig(
|
||||
provider.provider,
|
||||
provider.encryptedApiKey,
|
||||
provider.endpoint ?? null
|
||||
),
|
||||
} as CopilotProviderProfile;
|
||||
});
|
||||
}
|
||||
|
||||
private providerConfig(
|
||||
provider: ByokProvider,
|
||||
encryptedApiKey: string,
|
||||
endpoint: string | null
|
||||
) {
|
||||
const apiKey = this.crypto.decrypt(encryptedApiKey);
|
||||
switch (provider) {
|
||||
case ByokProvider.openai:
|
||||
case ByokProvider.gemini:
|
||||
case ByokProvider.anthropic:
|
||||
return { apiKey, ...(endpoint ? { baseURL: endpoint } : {}) };
|
||||
case ByokProvider.fal:
|
||||
return { apiKey };
|
||||
}
|
||||
}
|
||||
|
||||
private profileId(
|
||||
workspaceId: string,
|
||||
provider: ByokProvider,
|
||||
keyId: string,
|
||||
storage: 'server' | 'local'
|
||||
) {
|
||||
const hash = this.workspaceHash(workspaceId);
|
||||
const sanitizedKeyId = keyId.replaceAll(/[^a-zA-Z0-9-_]/g, '');
|
||||
return storage === 'local'
|
||||
? `byok-${hash}-${provider}-local-${sanitizedKeyId}`
|
||||
: `byok-${hash}-${provider}-${sanitizedKeyId}`;
|
||||
}
|
||||
|
||||
parseProfileMeta(
|
||||
providerId: string,
|
||||
workspaceId?: string
|
||||
): ByokProfileMeta | null {
|
||||
const match =
|
||||
/^byok-([a-f0-9]{12})-(openai|anthropic|gemini|fal)-(.+)$/.exec(
|
||||
providerId
|
||||
);
|
||||
if (!match) return null;
|
||||
if (workspaceId && match[1] !== this.workspaceHash(workspaceId)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const keyId = match[3];
|
||||
return {
|
||||
provider: match[2] as ByokProvider,
|
||||
source: keyId.startsWith('local-')
|
||||
? ByokProviderSource.Local
|
||||
: ByokProviderSource.Server,
|
||||
keyId: keyId.startsWith('local-') ? undefined : keyId,
|
||||
};
|
||||
}
|
||||
|
||||
private toKeyConfig(row: {
|
||||
id: string;
|
||||
provider: string;
|
||||
name: string;
|
||||
description: string | null;
|
||||
endpoint: string | null;
|
||||
sortOrder: number;
|
||||
enabled: boolean;
|
||||
disabledReason: string | null;
|
||||
lastValidatedAt: Date | null;
|
||||
lastValidationError: string | null;
|
||||
lastUsedAt: Date | null;
|
||||
lastErrorAt: Date | null;
|
||||
lastError: string | null;
|
||||
}): ByokKeyConfig {
|
||||
const provider = row.provider as ByokProvider;
|
||||
return {
|
||||
id: row.id,
|
||||
provider,
|
||||
name: row.name,
|
||||
description: row.description,
|
||||
storage: ByokKeyStorage.server,
|
||||
configured: true,
|
||||
enabled: row.enabled,
|
||||
endpoint: row.endpoint,
|
||||
endpointEditable: this.customEndpointSupported,
|
||||
sortOrder: row.sortOrder,
|
||||
capabilities: this.capabilities(provider, 'server'),
|
||||
testStatus: row.lastValidationError
|
||||
? ByokKeyTestStatus.failed
|
||||
: row.lastValidatedAt
|
||||
? ByokKeyTestStatus.passed
|
||||
: ByokKeyTestStatus.untested,
|
||||
disabledReason: row.disabledReason,
|
||||
lastTestedAt: row.lastValidatedAt,
|
||||
lastTestError: row.lastValidationError,
|
||||
lastUsedAt: row.lastUsedAt,
|
||||
lastErrorAt: row.lastErrorAt,
|
||||
lastError: row.lastError,
|
||||
};
|
||||
}
|
||||
|
||||
private capabilities(provider: ByokProvider, storage: 'server' | 'local') {
|
||||
switch (provider) {
|
||||
case ByokProvider.openai:
|
||||
return ['Text', 'Image input', 'Actions', 'Image generate'];
|
||||
case ByokProvider.anthropic:
|
||||
return ['Text', 'Image input'];
|
||||
case ByokProvider.gemini:
|
||||
return storage === 'server'
|
||||
? [
|
||||
'Text',
|
||||
'Image input',
|
||||
'Actions',
|
||||
'Image generate',
|
||||
'Transcript',
|
||||
'Indexing',
|
||||
]
|
||||
: ['Text', 'Image input', 'Actions', 'Image generate'];
|
||||
case ByokProvider.fal:
|
||||
return ['Image generate'];
|
||||
}
|
||||
}
|
||||
|
||||
private buildWarnings(keys: ByokKeyConfig[]) {
|
||||
const activeServerGemini = keys.some(
|
||||
key =>
|
||||
key.provider === ByokProvider.gemini &&
|
||||
key.storage === ByokKeyStorage.server &&
|
||||
key.enabled
|
||||
);
|
||||
if (activeServerGemini) {
|
||||
return [];
|
||||
}
|
||||
return [
|
||||
{
|
||||
featureKind: 'transcript',
|
||||
reason:
|
||||
'Transcript and workspace indexing require a server Gemini BYOK key or AFFiNE AI plan fallback.',
|
||||
requiredProviders: [ByokProvider.gemini],
|
||||
},
|
||||
{
|
||||
featureKind: 'workspace_indexing',
|
||||
reason:
|
||||
'Workspace indexing requires a server Gemini BYOK key or AFFiNE AI plan fallback.',
|
||||
requiredProviders: [ByokProvider.gemini],
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
private normalizeEndpoint(endpoint?: string | null) {
|
||||
if (!endpoint) return null;
|
||||
if (!this.customEndpointSupported) {
|
||||
throw new BadRequestException('Custom BYOK endpoint is not supported.');
|
||||
}
|
||||
let parsed: URL;
|
||||
try {
|
||||
parsed = new URL(endpoint);
|
||||
} catch {
|
||||
throw new BadRequestException('Invalid BYOK endpoint.');
|
||||
}
|
||||
if (!['https:', 'http:'].includes(parsed.protocol)) {
|
||||
throw new BadRequestException('BYOK endpoint must use HTTP or HTTPS.');
|
||||
}
|
||||
return parsed.toString().replace(/\/$/, '');
|
||||
}
|
||||
|
||||
private assertProvider(provider: ByokProvider) {
|
||||
if (!BYOK_ALLOWED_PROVIDERS.includes(provider)) {
|
||||
throw new BadRequestException('Unsupported BYOK provider.');
|
||||
}
|
||||
}
|
||||
|
||||
private async runProviderProbe(
|
||||
provider: ByokProvider,
|
||||
apiKey: string,
|
||||
endpoint: string | null
|
||||
) {
|
||||
const controller = new AbortController();
|
||||
const timeout = setTimeout(() => controller.abort(), TEST_TIMEOUT_MS);
|
||||
try {
|
||||
const request = this.buildProbeRequest(provider, apiKey, endpoint);
|
||||
const response = await fetch(request.url, {
|
||||
method: request.method,
|
||||
headers: request.headers as unknown as Record<string, string>,
|
||||
signal: controller.signal,
|
||||
});
|
||||
if (!response.ok) {
|
||||
throw new BadRequestException(
|
||||
this.providerProbeFailureMessage(response.status)
|
||||
);
|
||||
}
|
||||
} finally {
|
||||
clearTimeout(timeout);
|
||||
}
|
||||
}
|
||||
|
||||
private buildProbeRequest(
|
||||
provider: ByokProvider,
|
||||
apiKey: string,
|
||||
endpoint: string | null
|
||||
) {
|
||||
switch (provider) {
|
||||
case ByokProvider.openai:
|
||||
return {
|
||||
method: 'GET',
|
||||
url: `${endpoint ?? 'https://api.openai.com/v1'}/models`,
|
||||
headers: { Authorization: `Bearer ${apiKey}` },
|
||||
};
|
||||
case ByokProvider.anthropic:
|
||||
return {
|
||||
method: 'GET',
|
||||
url: `${endpoint ?? 'https://api.anthropic.com/v1'}/models`,
|
||||
headers: {
|
||||
'x-api-key': apiKey,
|
||||
'anthropic-version': '2023-06-01',
|
||||
},
|
||||
};
|
||||
case ByokProvider.gemini:
|
||||
return {
|
||||
method: 'GET',
|
||||
url: `${endpoint ?? 'https://generativelanguage.googleapis.com/v1beta'}/models`,
|
||||
headers: { 'x-goog-api-key': apiKey },
|
||||
};
|
||||
case ByokProvider.fal:
|
||||
return {
|
||||
method: 'GET',
|
||||
url: 'https://api.fal.ai/v1/models?limit=10',
|
||||
headers: { Authorization: `Key ${apiKey}` },
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private sanitizeError(error: unknown) {
|
||||
if (error instanceof Error && error.name === 'AbortError') {
|
||||
return 'Provider key test timed out.';
|
||||
}
|
||||
if (error instanceof BadRequestException && error.message) {
|
||||
return error.message.slice(0, 300);
|
||||
}
|
||||
return 'Provider request failed.';
|
||||
}
|
||||
|
||||
private providerProbeFailureMessage(status: number) {
|
||||
switch (status) {
|
||||
case 401:
|
||||
return 'Provider rejected the BYOK key.';
|
||||
case 403:
|
||||
return 'Provider rejected the BYOK key permissions.';
|
||||
case 404:
|
||||
return 'Provider probe endpoint was not found.';
|
||||
case 429:
|
||||
return 'Provider rate limit exceeded while testing the key.';
|
||||
default:
|
||||
return status >= 500
|
||||
? 'Provider service is unavailable.'
|
||||
: `Provider key test failed with HTTP ${status}.`;
|
||||
}
|
||||
}
|
||||
|
||||
private workspaceHash(workspaceId: string) {
|
||||
return createHash('sha256').update(workspaceId).digest('hex').slice(0, 12);
|
||||
}
|
||||
|
||||
private leaseCacheKey(leaseId: string) {
|
||||
return `copilot:byok:lease:${leaseId}`;
|
||||
}
|
||||
|
||||
private async getActiveLocalLease(activeCacheKey: string) {
|
||||
const active = await this.cache.get<LocalLeaseActive>(activeCacheKey);
|
||||
if (!active) return null;
|
||||
if (await this.cache.has(this.leaseCacheKey(active.leaseId))) {
|
||||
return { leaseId: active.leaseId, expiresAt: new Date(active.expiresAt) };
|
||||
}
|
||||
await this.cache.delete(activeCacheKey);
|
||||
return null;
|
||||
}
|
||||
|
||||
private localLeaseActiveCacheKey(input: {
|
||||
workspaceId: string;
|
||||
userId: string;
|
||||
providers: ByokLocalLeaseProvider[];
|
||||
}) {
|
||||
const fingerprint = createHmac(
|
||||
'sha256',
|
||||
this.crypto.keyPair.sha256.privateKey
|
||||
)
|
||||
.update(
|
||||
JSON.stringify(
|
||||
input.providers.map(provider => ({
|
||||
provider: provider.provider,
|
||||
name: provider.name,
|
||||
description: provider.description ?? null,
|
||||
apiKey: provider.apiKey,
|
||||
endpoint: provider.endpoint ?? null,
|
||||
sortOrder: provider.sortOrder ?? 0,
|
||||
enabled: provider.enabled ?? true,
|
||||
}))
|
||||
)
|
||||
)
|
||||
.digest('hex');
|
||||
return `copilot:byok:lease:active:${input.workspaceId}:${input.userId}:${fingerprint}`;
|
||||
}
|
||||
}
|
||||
79
packages/backend/server/src/plugins/copilot/byok/types.ts
Normal file
79
packages/backend/server/src/plugins/copilot/byok/types.ts
Normal file
@@ -0,0 +1,79 @@
|
||||
import { registerEnumType } from '@nestjs/graphql';
|
||||
|
||||
import { CopilotProviderType } from '../providers/types';
|
||||
|
||||
export enum ByokProvider {
|
||||
openai = 'openai',
|
||||
anthropic = 'anthropic',
|
||||
gemini = 'gemini',
|
||||
fal = 'fal',
|
||||
}
|
||||
|
||||
export enum ByokKeyStorage {
|
||||
server = 'server',
|
||||
local = 'local',
|
||||
}
|
||||
|
||||
export enum ByokKeyTestStatus {
|
||||
untested = 'untested',
|
||||
passed = 'passed',
|
||||
failed = 'failed',
|
||||
}
|
||||
|
||||
export enum ByokProviderSource {
|
||||
Server = 'byok_server',
|
||||
Local = 'byok_local',
|
||||
AffinePlan = 'affine_plan',
|
||||
}
|
||||
|
||||
export type ByokFeatureKind =
|
||||
| 'chat'
|
||||
| 'action'
|
||||
| 'image'
|
||||
| 'embedding'
|
||||
| 'rerank'
|
||||
| 'transcript'
|
||||
| 'workspace_indexing';
|
||||
|
||||
export const BYOK_ALLOWED_PROVIDERS = [
|
||||
ByokProvider.openai,
|
||||
ByokProvider.anthropic,
|
||||
ByokProvider.gemini,
|
||||
ByokProvider.fal,
|
||||
] as const;
|
||||
|
||||
export function byokProviderToCopilotType(provider: ByokProvider) {
|
||||
switch (provider) {
|
||||
case ByokProvider.openai:
|
||||
return CopilotProviderType.OpenAI;
|
||||
case ByokProvider.anthropic:
|
||||
return CopilotProviderType.Anthropic;
|
||||
case ByokProvider.gemini:
|
||||
return CopilotProviderType.Gemini;
|
||||
case ByokProvider.fal:
|
||||
return CopilotProviderType.FAL;
|
||||
}
|
||||
}
|
||||
|
||||
export function copilotTypeToByokProvider(type: CopilotProviderType) {
|
||||
switch (type) {
|
||||
case CopilotProviderType.OpenAI:
|
||||
return ByokProvider.openai;
|
||||
case CopilotProviderType.Anthropic:
|
||||
return ByokProvider.anthropic;
|
||||
case CopilotProviderType.Gemini:
|
||||
return ByokProvider.gemini;
|
||||
case CopilotProviderType.FAL:
|
||||
return ByokProvider.fal;
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
export function isByokProvider(value: string): value is ByokProvider {
|
||||
return (BYOK_ALLOWED_PROVIDERS as readonly string[]).includes(value);
|
||||
}
|
||||
|
||||
registerEnumType(ByokProvider, { name: 'ByokProvider' });
|
||||
registerEnumType(ByokKeyStorage, { name: 'ByokKeyStorage' });
|
||||
registerEnumType(ByokKeyTestStatus, { name: 'ByokKeyTestStatus' });
|
||||
@@ -12,9 +12,7 @@ import {
|
||||
import { CloudflareWorkersAIConfig } from './providers/cloudflare';
|
||||
import type { FalConfig } from './providers/fal';
|
||||
import { GeminiGenerativeConfig, GeminiVertexConfig } from './providers/gemini';
|
||||
import { MorphConfig } from './providers/morph';
|
||||
import { OpenAIConfig } from './providers/openai';
|
||||
import { PerplexityConfig } from './providers/perplexity';
|
||||
import {
|
||||
CopilotProviderType,
|
||||
ModelOutputType,
|
||||
@@ -27,10 +25,8 @@ export type CopilotProviderConfigMap = {
|
||||
[CopilotProviderType.FAL]: FalConfig;
|
||||
[CopilotProviderType.Gemini]: GeminiGenerativeConfig;
|
||||
[CopilotProviderType.GeminiVertex]: GeminiVertexConfig;
|
||||
[CopilotProviderType.Perplexity]: PerplexityConfig;
|
||||
[CopilotProviderType.Anthropic]: AnthropicOfficialConfig;
|
||||
[CopilotProviderType.AnthropicVertex]: AnthropicVertexConfig;
|
||||
[CopilotProviderType.Morph]: MorphConfig;
|
||||
};
|
||||
|
||||
export type ProviderSpecificConfig =
|
||||
@@ -138,20 +134,11 @@ const VertexProviderConfigShape = z.object({
|
||||
fetch: z.any().optional(),
|
||||
});
|
||||
|
||||
const PerplexityConfigShape = z.object({
|
||||
apiKey: z.string(),
|
||||
endpoint: z.string().optional(),
|
||||
});
|
||||
|
||||
const AnthropicOfficialConfigShape = z.object({
|
||||
apiKey: z.string(),
|
||||
baseURL: z.string().optional(),
|
||||
});
|
||||
|
||||
const MorphConfigShape = z.object({
|
||||
apiKey: z.string().optional(),
|
||||
});
|
||||
|
||||
const CopilotProviderProfileShape = z.discriminatedUnion('type', [
|
||||
CopilotProviderProfileBaseShape.extend({
|
||||
type: z.literal(CopilotProviderType.OpenAI),
|
||||
@@ -173,10 +160,6 @@ const CopilotProviderProfileShape = z.discriminatedUnion('type', [
|
||||
type: z.literal(CopilotProviderType.GeminiVertex),
|
||||
config: VertexProviderConfigShape,
|
||||
}),
|
||||
CopilotProviderProfileBaseShape.extend({
|
||||
type: z.literal(CopilotProviderType.Perplexity),
|
||||
config: PerplexityConfigShape,
|
||||
}),
|
||||
CopilotProviderProfileBaseShape.extend({
|
||||
type: z.literal(CopilotProviderType.Anthropic),
|
||||
config: AnthropicOfficialConfigShape,
|
||||
@@ -185,10 +168,6 @@ const CopilotProviderProfileShape = z.discriminatedUnion('type', [
|
||||
type: z.literal(CopilotProviderType.AnthropicVertex),
|
||||
config: VertexProviderConfigShape,
|
||||
}),
|
||||
CopilotProviderProfileBaseShape.extend({
|
||||
type: z.literal(CopilotProviderType.Morph),
|
||||
config: MorphConfigShape,
|
||||
}),
|
||||
]);
|
||||
|
||||
const CopilotProviderDefaultsShape = z.object({
|
||||
@@ -205,6 +184,13 @@ declare global {
|
||||
interface AppConfigSchema {
|
||||
copilot: {
|
||||
enabled: boolean;
|
||||
byok: {
|
||||
enabled: ConfigItem<boolean>;
|
||||
allowedProviders: ConfigItem<
|
||||
Array<'openai' | 'anthropic' | 'gemini' | 'fal'>
|
||||
>;
|
||||
allowCustomEndpoint: ConfigItem<boolean>;
|
||||
};
|
||||
unsplash: ConfigItem<{
|
||||
key: string;
|
||||
}>;
|
||||
@@ -220,10 +206,8 @@ declare global {
|
||||
fal: ConfigItem<FalConfig>;
|
||||
gemini: ConfigItem<GeminiGenerativeConfig>;
|
||||
geminiVertex: ConfigItem<GeminiVertexConfig>;
|
||||
perplexity: ConfigItem<PerplexityConfig>;
|
||||
anthropic: ConfigItem<AnthropicOfficialConfig>;
|
||||
anthropicVertex: ConfigItem<AnthropicVertexConfig>;
|
||||
morph: ConfigItem<MorphConfig>;
|
||||
};
|
||||
};
|
||||
}
|
||||
@@ -234,6 +218,21 @@ defineModuleConfig('copilot', {
|
||||
desc: 'Whether to enable the copilot plugin. <br> Document: <a href="https://docs.affine.pro/self-host-affine/administer/ai" target="_blank">https://docs.affine.pro/self-host-affine/administer/ai</a>',
|
||||
default: false,
|
||||
},
|
||||
'byok.enabled': {
|
||||
desc: 'Whether to enable workspace BYOK.',
|
||||
default: true,
|
||||
shape: z.boolean(),
|
||||
},
|
||||
'byok.allowedProviders': {
|
||||
desc: 'The allowlist for workspace BYOK providers.',
|
||||
default: ['openai', 'anthropic', 'gemini', 'fal'],
|
||||
shape: z.array(z.enum(['openai', 'anthropic', 'gemini', 'fal'])),
|
||||
},
|
||||
'byok.allowCustomEndpoint': {
|
||||
desc: 'Whether workspace BYOK custom endpoints are accepted.',
|
||||
default: false,
|
||||
shape: z.boolean(),
|
||||
},
|
||||
'providers.profiles': {
|
||||
desc: 'The profile list for copilot providers.',
|
||||
default: [],
|
||||
@@ -277,12 +276,6 @@ defineModuleConfig('copilot', {
|
||||
default: {},
|
||||
schema: VertexSchema,
|
||||
},
|
||||
'providers.perplexity': {
|
||||
desc: 'The config for the perplexity provider.',
|
||||
default: {
|
||||
apiKey: '',
|
||||
},
|
||||
},
|
||||
'providers.anthropic': {
|
||||
desc: 'The config for the anthropic provider.',
|
||||
default: {
|
||||
@@ -295,10 +288,6 @@ defineModuleConfig('copilot', {
|
||||
default: {},
|
||||
schema: VertexSchema,
|
||||
},
|
||||
'providers.morph': {
|
||||
desc: 'The config for the morph provider.',
|
||||
default: {},
|
||||
},
|
||||
unsplash: {
|
||||
desc: 'The config for the unsplash key.',
|
||||
default: {
|
||||
|
||||
@@ -15,7 +15,11 @@ import {
|
||||
Models,
|
||||
} from '../../../models';
|
||||
import { CopilotEmbeddingClientService } from '../embedding/client';
|
||||
import type { EmbeddingClient } from '../embedding/types';
|
||||
import type {
|
||||
EmbeddingCallOptions,
|
||||
EmbeddingClient,
|
||||
EmbeddingRouteContext,
|
||||
} from '../embedding/types';
|
||||
import { ContextSession } from './session';
|
||||
|
||||
const CONTEXT_SESSION_KEY = 'context-session';
|
||||
@@ -62,6 +66,14 @@ export class CopilotContextService implements OnApplicationBootstrap {
|
||||
return this.client ?? this.embeddingClients.getClient();
|
||||
}
|
||||
|
||||
private embeddingOptions(
|
||||
workspaceId: string,
|
||||
signal?: AbortSignal,
|
||||
routeContext: EmbeddingRouteContext = {}
|
||||
): EmbeddingCallOptions {
|
||||
return { workspaceId, signal, ...routeContext, featureKind: 'embedding' };
|
||||
}
|
||||
|
||||
private async saveConfig(
|
||||
contextId: string,
|
||||
config: ContextConfig,
|
||||
@@ -172,11 +184,13 @@ export class CopilotContextService implements OnApplicationBootstrap {
|
||||
content: string,
|
||||
topK: number = 5,
|
||||
signal?: AbortSignal,
|
||||
threshold: number = 0.5
|
||||
threshold: number = 0.5,
|
||||
routeContext?: EmbeddingRouteContext
|
||||
) {
|
||||
const client = this.embeddingClient;
|
||||
if (!client) return [];
|
||||
const embedding = await client.getEmbedding(content, signal);
|
||||
const options = this.embeddingOptions(workspaceId, signal, routeContext);
|
||||
const embedding = await client.getEmbedding(content, options);
|
||||
if (!embedding) return [];
|
||||
|
||||
const blobChunks = await this.models.copilotWorkspace.matchBlobEmbedding(
|
||||
@@ -187,7 +201,7 @@ export class CopilotContextService implements OnApplicationBootstrap {
|
||||
);
|
||||
if (!blobChunks.length) return [];
|
||||
|
||||
return await client.reRank(content, blobChunks, topK, signal);
|
||||
return await client.reRank(content, blobChunks, topK, options);
|
||||
}
|
||||
|
||||
async matchWorkspaceFiles(
|
||||
@@ -195,11 +209,13 @@ export class CopilotContextService implements OnApplicationBootstrap {
|
||||
content: string,
|
||||
topK: number = 5,
|
||||
signal?: AbortSignal,
|
||||
threshold: number = 0.5
|
||||
threshold: number = 0.5,
|
||||
routeContext?: EmbeddingRouteContext
|
||||
) {
|
||||
const client = this.embeddingClient;
|
||||
if (!client) return [];
|
||||
const embedding = await client.getEmbedding(content, signal);
|
||||
const options = this.embeddingOptions(workspaceId, signal, routeContext);
|
||||
const embedding = await client.getEmbedding(content, options);
|
||||
if (!embedding) return [];
|
||||
|
||||
const fileChunks = await this.models.copilotWorkspace.matchFileEmbedding(
|
||||
@@ -210,7 +226,7 @@ export class CopilotContextService implements OnApplicationBootstrap {
|
||||
);
|
||||
if (!fileChunks.length) return [];
|
||||
|
||||
return await client.reRank(content, fileChunks, topK, signal);
|
||||
return await client.reRank(content, fileChunks, topK, options);
|
||||
}
|
||||
|
||||
async matchWorkspaceDocs(
|
||||
@@ -218,11 +234,13 @@ export class CopilotContextService implements OnApplicationBootstrap {
|
||||
content: string,
|
||||
topK: number = 5,
|
||||
signal?: AbortSignal,
|
||||
threshold: number = 0.5
|
||||
threshold: number = 0.5,
|
||||
routeContext?: EmbeddingRouteContext
|
||||
) {
|
||||
const client = this.embeddingClient;
|
||||
if (!client) return [];
|
||||
const embedding = await client.getEmbedding(content, signal);
|
||||
const options = this.embeddingOptions(workspaceId, signal, routeContext);
|
||||
const embedding = await client.getEmbedding(content, options);
|
||||
if (!embedding) return [];
|
||||
|
||||
const workspaceChunks =
|
||||
@@ -234,7 +252,7 @@ export class CopilotContextService implements OnApplicationBootstrap {
|
||||
);
|
||||
if (!workspaceChunks.length) return [];
|
||||
|
||||
return await client.reRank(content, workspaceChunks, topK, signal);
|
||||
return await client.reRank(content, workspaceChunks, topK, options);
|
||||
}
|
||||
|
||||
async matchWorkspaceAll(
|
||||
@@ -244,11 +262,13 @@ export class CopilotContextService implements OnApplicationBootstrap {
|
||||
signal?: AbortSignal,
|
||||
threshold: number = 0.8,
|
||||
docIds?: string[],
|
||||
scopedThreshold: number = 0.85
|
||||
scopedThreshold: number = 0.85,
|
||||
routeContext?: EmbeddingRouteContext
|
||||
) {
|
||||
const client = this.embeddingClient;
|
||||
if (!client) return [];
|
||||
const embedding = await client.getEmbedding(content, signal);
|
||||
const options = this.embeddingOptions(workspaceId, signal, routeContext);
|
||||
const embedding = await client.getEmbedding(content, options);
|
||||
if (!embedding) return [];
|
||||
|
||||
const [fileChunks, blobChunks, workspaceChunks, scopedWorkspaceChunks] =
|
||||
@@ -300,7 +320,7 @@ export class CopilotContextService implements OnApplicationBootstrap {
|
||||
...(scopedWorkspaceChunks || []),
|
||||
],
|
||||
topK,
|
||||
signal
|
||||
options
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,11 @@ import {
|
||||
FileChunkSimilarity,
|
||||
Models,
|
||||
} from '../../../models';
|
||||
import { EmbeddingClient } from '../embedding/types';
|
||||
import type {
|
||||
EmbeddingCallOptions,
|
||||
EmbeddingClient,
|
||||
EmbeddingRouteContext,
|
||||
} from '../embedding/types';
|
||||
|
||||
export class ContextSession implements AsyncDisposable {
|
||||
constructor(
|
||||
@@ -69,6 +73,18 @@ export class ContextSession implements AsyncDisposable {
|
||||
);
|
||||
}
|
||||
|
||||
private embeddingOptions(
|
||||
signal?: AbortSignal,
|
||||
routeContext: EmbeddingRouteContext = {}
|
||||
): EmbeddingCallOptions {
|
||||
return {
|
||||
workspaceId: this.workspaceId,
|
||||
signal,
|
||||
...routeContext,
|
||||
featureKind: 'embedding',
|
||||
};
|
||||
}
|
||||
|
||||
async addCategoryRecord(type: ContextCategories, id: string, docs: string[]) {
|
||||
const category = this.config.categories.find(
|
||||
c => c.type === type && c.id === id
|
||||
@@ -269,10 +285,12 @@ export class ContextSession implements AsyncDisposable {
|
||||
topK: number = 5,
|
||||
signal?: AbortSignal,
|
||||
scopedThreshold: number = 0.85,
|
||||
threshold: number = 0.5
|
||||
threshold: number = 0.5,
|
||||
routeContext?: EmbeddingRouteContext
|
||||
): Promise<FileChunkSimilarity[]> {
|
||||
if (!this.client) return [];
|
||||
const embedding = await this.client.getEmbedding(content, signal);
|
||||
const options = this.embeddingOptions(signal, routeContext);
|
||||
const embedding = await this.client.getEmbedding(content, options);
|
||||
if (!embedding) return [];
|
||||
|
||||
const [context, workspace] = await Promise.all([
|
||||
@@ -305,7 +323,7 @@ export class ContextSession implements AsyncDisposable {
|
||||
...workspace,
|
||||
],
|
||||
topK,
|
||||
signal
|
||||
options
|
||||
);
|
||||
}
|
||||
|
||||
@@ -322,10 +340,12 @@ export class ContextSession implements AsyncDisposable {
|
||||
topK: number = 5,
|
||||
signal?: AbortSignal,
|
||||
scopedThreshold: number = 0.85,
|
||||
threshold: number = 0.5
|
||||
threshold: number = 0.5,
|
||||
routeContext?: EmbeddingRouteContext
|
||||
) {
|
||||
if (!this.client) return [];
|
||||
const embedding = await this.client.getEmbedding(content, signal);
|
||||
const options = this.embeddingOptions(signal, routeContext);
|
||||
const embedding = await this.client.getEmbedding(content, options);
|
||||
if (!embedding) return [];
|
||||
|
||||
const docIds = this.docIds;
|
||||
@@ -349,7 +369,7 @@ export class ContextSession implements AsyncDisposable {
|
||||
content,
|
||||
[...inContext, ...workspace],
|
||||
topK,
|
||||
signal
|
||||
options
|
||||
);
|
||||
|
||||
// sort result, doc recorded in context first
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
|
||||
import { CopilotQuotaExceeded } from '../../../base';
|
||||
import { QuotaService } from '../../../core/quota';
|
||||
import { QuotaService } from '../../../core/quota/service';
|
||||
import { Models } from '../../../models';
|
||||
import type { Turn } from '../core';
|
||||
import type { ResolvedPrompt } from '../prompt';
|
||||
@@ -31,12 +31,16 @@ export class ConversationPolicy {
|
||||
}
|
||||
|
||||
async checkQuota(userId: string) {
|
||||
const { limit, used } = await this.getQuota(userId);
|
||||
if (limit && Number.isFinite(limit) && used >= limit) {
|
||||
if (!(await this.hasQuota(userId))) {
|
||||
throw new CopilotQuotaExceeded();
|
||||
}
|
||||
}
|
||||
|
||||
async hasQuota(userId: string) {
|
||||
const { limit, used } = await this.getQuota(userId);
|
||||
return !(limit !== undefined && Number.isFinite(limit) && used >= limit);
|
||||
}
|
||||
|
||||
shouldScheduleTitle(prompt: Pick<ResolvedPrompt, 'action'>) {
|
||||
return !prompt.action;
|
||||
}
|
||||
|
||||
@@ -11,7 +11,12 @@ import {
|
||||
import { type CopilotRerankRequest } from '../providers/types';
|
||||
import { CapabilityRuntime } from '../runtime/capability-runtime';
|
||||
import { TaskPolicy } from '../runtime/task-policy';
|
||||
import { EmbeddingClient, type ReRankResult } from './types';
|
||||
import {
|
||||
type EmbeddingCallOptionsInput,
|
||||
EmbeddingClient,
|
||||
normalizeEmbeddingCallOptions,
|
||||
type ReRankResult,
|
||||
} from './types';
|
||||
|
||||
class ProductionEmbeddingClient extends EmbeddingClient {
|
||||
private readonly logger = new Logger(ProductionEmbeddingClient.name);
|
||||
@@ -35,10 +40,19 @@ class ProductionEmbeddingClient extends EmbeddingClient {
|
||||
return result;
|
||||
}
|
||||
|
||||
async getEmbeddings(input: string[]): Promise<Embedding[]> {
|
||||
async getEmbeddings(
|
||||
input: string[],
|
||||
options?: EmbeddingCallOptionsInput
|
||||
): Promise<Embedding[]> {
|
||||
const normalizedOptions = normalizeEmbeddingCallOptions(options);
|
||||
const modelId = this.taskPolicy.resolveEmbeddingModelId();
|
||||
const embeddings = await this.runtime.embed(modelId, input, {
|
||||
dimensions: EMBEDDING_DIMENSIONS,
|
||||
signal: normalizedOptions.signal,
|
||||
user: normalizedOptions.userId,
|
||||
workspace: normalizedOptions.workspaceId,
|
||||
byokLeaseId: normalizedOptions.byokLeaseId,
|
||||
featureKind: normalizedOptions.featureKind ?? 'embedding',
|
||||
});
|
||||
if (embeddings.length !== input.length) {
|
||||
throw new CopilotFailedToGenerateEmbedding({
|
||||
@@ -67,8 +81,9 @@ class ProductionEmbeddingClient extends EmbeddingClient {
|
||||
>(
|
||||
query: string,
|
||||
embeddings: Chunk[],
|
||||
signal?: AbortSignal
|
||||
options?: EmbeddingCallOptionsInput
|
||||
): Promise<ReRankResult> {
|
||||
const normalizedOptions = normalizeEmbeddingCallOptions(options);
|
||||
if (!embeddings.length) return [];
|
||||
|
||||
const rerankRequest: CopilotRerankRequest = {
|
||||
@@ -82,7 +97,13 @@ class ProductionEmbeddingClient extends EmbeddingClient {
|
||||
const ranks = await this.runtime.rerank(
|
||||
this.taskPolicy.resolveRerankModelId(),
|
||||
rerankRequest,
|
||||
{ signal }
|
||||
{
|
||||
signal: normalizedOptions.signal,
|
||||
user: normalizedOptions.userId,
|
||||
workspace: normalizedOptions.workspaceId,
|
||||
byokLeaseId: normalizedOptions.byokLeaseId,
|
||||
featureKind: 'rerank',
|
||||
}
|
||||
);
|
||||
|
||||
try {
|
||||
@@ -105,8 +126,9 @@ class ProductionEmbeddingClient extends EmbeddingClient {
|
||||
query: string,
|
||||
embeddings: Chunk[],
|
||||
topK: number,
|
||||
signal?: AbortSignal
|
||||
options?: EmbeddingCallOptionsInput
|
||||
): Promise<Chunk[]> {
|
||||
const normalizedOptions = normalizeEmbeddingCallOptions(options);
|
||||
// search in context and workspace may find same chunks, de-duplicate them
|
||||
const { deduped: dedupedEmbeddings } = embeddings.reduce(
|
||||
(acc, e) => {
|
||||
@@ -138,14 +160,19 @@ class ProductionEmbeddingClient extends EmbeddingClient {
|
||||
const ranks = await this.getEmbeddingRelevance(
|
||||
query,
|
||||
sortedEmbeddings,
|
||||
signal
|
||||
normalizedOptions
|
||||
);
|
||||
if (sortedEmbeddings.length !== ranks.length) {
|
||||
// llm return wrong result, fallback to default sorting
|
||||
this.logger.warn(
|
||||
`Batch size mismatch: expected ${sortedEmbeddings.length}, got ${ranks.length}`
|
||||
);
|
||||
return await super.reRank(query, dedupedEmbeddings, topK, signal);
|
||||
return await super.reRank(
|
||||
query,
|
||||
dedupedEmbeddings,
|
||||
topK,
|
||||
normalizedOptions
|
||||
);
|
||||
}
|
||||
|
||||
const highConfidenceChunks = ranks
|
||||
@@ -164,7 +191,12 @@ class ProductionEmbeddingClient extends EmbeddingClient {
|
||||
return highConfidenceChunks.slice(0, topK);
|
||||
} catch (error) {
|
||||
this.logger.warn('ReRank failed, falling back to default sorting', error);
|
||||
return await super.reRank(query, dedupedEmbeddings, topK, signal);
|
||||
return await super.reRank(
|
||||
query,
|
||||
dedupedEmbeddings,
|
||||
topK,
|
||||
normalizedOptions
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -180,7 +212,8 @@ export class CopilotEmbeddingClientService {
|
||||
|
||||
async refresh() {
|
||||
const client = new ProductionEmbeddingClient(this.taskPolicy, this.runtime);
|
||||
this.client = (await client.configured()) ? client : undefined;
|
||||
await client.configured();
|
||||
this.client = client;
|
||||
return this.client;
|
||||
}
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ import { Models } from '../../../models';
|
||||
import { CopilotStorage } from '../storage';
|
||||
import { readStream } from '../utils';
|
||||
import { CopilotEmbeddingClientService } from './client';
|
||||
import type { Chunk, DocFragment } from './types';
|
||||
import type { Chunk, DocFragment, EmbeddingCallOptions } from './types';
|
||||
import { EmbeddingClient } from './types';
|
||||
|
||||
@Injectable()
|
||||
@@ -242,6 +242,19 @@ export class CopilotEmbeddingJob {
|
||||
return new File([buffer], fileName);
|
||||
}
|
||||
|
||||
private workspaceIndexingOptions(
|
||||
workspaceId: string,
|
||||
signal?: AbortSignal,
|
||||
userId?: string
|
||||
): EmbeddingCallOptions {
|
||||
return {
|
||||
workspaceId,
|
||||
userId,
|
||||
signal,
|
||||
featureKind: 'workspace_indexing',
|
||||
};
|
||||
}
|
||||
|
||||
@OnJob('copilot.embedding.files')
|
||||
async embedPendingFile({
|
||||
userId,
|
||||
@@ -266,7 +279,10 @@ export class CopilotEmbeddingJob {
|
||||
const total = chunks.reduce((acc, c) => acc + c.length, 0);
|
||||
|
||||
for (const chunk of chunks) {
|
||||
const embeddings = await this.embeddingClient.generateEmbeddings(chunk);
|
||||
const embeddings = await this.embeddingClient.generateEmbeddings(
|
||||
chunk,
|
||||
this.workspaceIndexingOptions(workspaceId, undefined, userId)
|
||||
);
|
||||
if (contextId) {
|
||||
// for context files
|
||||
await this.models.copilotContext.insertFileEmbedding(
|
||||
@@ -320,7 +336,10 @@ export class CopilotEmbeddingJob {
|
||||
const total = chunks.reduce((acc, c) => acc + c.length, 0);
|
||||
|
||||
for (const chunk of chunks) {
|
||||
const embeddings = await this.embeddingClient.generateEmbeddings(chunk);
|
||||
const embeddings = await this.embeddingClient.generateEmbeddings(
|
||||
chunk,
|
||||
this.workspaceIndexingOptions(workspaceId)
|
||||
);
|
||||
await this.models.copilotWorkspace.insertBlobEmbeddings(
|
||||
workspaceId,
|
||||
blobId,
|
||||
@@ -462,7 +481,7 @@ export class CopilotEmbeddingJob {
|
||||
`${fragment.title || 'Untitled'}.md`
|
||||
),
|
||||
chunks => this.formatDocChunks(chunks, fragment),
|
||||
signal
|
||||
this.workspaceIndexingOptions(workspaceId, signal)
|
||||
);
|
||||
|
||||
for (const chunks of embeddings) {
|
||||
|
||||
@@ -6,6 +6,7 @@ import { CopilotContextFileNotSupported } from '../../../base';
|
||||
import type { PageDocContent } from '../../../core/utils/blocksuite';
|
||||
import { ChunkSimilarity, Embedding } from '../../../models';
|
||||
import { parseDoc } from '../../../native';
|
||||
import type { ByokFeatureKind } from '../byok/types';
|
||||
|
||||
declare global {
|
||||
interface Events {
|
||||
@@ -103,6 +104,35 @@ export type Chunk = {
|
||||
content: string;
|
||||
};
|
||||
|
||||
export type EmbeddingCallOptions = {
|
||||
signal?: AbortSignal;
|
||||
userId?: string;
|
||||
workspaceId?: string;
|
||||
byokLeaseId?: string;
|
||||
featureKind?: Extract<
|
||||
ByokFeatureKind,
|
||||
'embedding' | 'workspace_indexing' | 'rerank'
|
||||
>;
|
||||
};
|
||||
|
||||
export type EmbeddingCallOptionsInput = AbortSignal | EmbeddingCallOptions;
|
||||
export type EmbeddingRouteContext = Pick<
|
||||
EmbeddingCallOptions,
|
||||
'userId' | 'byokLeaseId'
|
||||
>;
|
||||
|
||||
export function normalizeEmbeddingCallOptions(
|
||||
options?: EmbeddingCallOptionsInput
|
||||
): EmbeddingCallOptions {
|
||||
if (!options) {
|
||||
return {};
|
||||
}
|
||||
if ('aborted' in options && 'addEventListener' in options) {
|
||||
return { signal: options };
|
||||
}
|
||||
return options;
|
||||
}
|
||||
|
||||
export abstract class EmbeddingClient {
|
||||
async configured() {
|
||||
return true;
|
||||
@@ -111,11 +141,14 @@ export abstract class EmbeddingClient {
|
||||
async getFileEmbeddings(
|
||||
file: File,
|
||||
chunkMapper: (chunk: Chunk[]) => Chunk[],
|
||||
signal?: AbortSignal
|
||||
options?: EmbeddingCallOptionsInput
|
||||
): Promise<Embedding[][]> {
|
||||
const chunks = await this.getFileChunks(file, signal);
|
||||
const normalizedOptions = normalizeEmbeddingCallOptions(options);
|
||||
const chunks = await this.getFileChunks(file, normalizedOptions.signal);
|
||||
const chunkedEmbeddings = await Promise.all(
|
||||
chunks.map(chunk => this.generateEmbeddings(chunkMapper(chunk)))
|
||||
chunks.map(chunk =>
|
||||
this.generateEmbeddings(chunkMapper(chunk), normalizedOptions)
|
||||
)
|
||||
);
|
||||
return chunkedEmbeddings;
|
||||
}
|
||||
@@ -154,8 +187,9 @@ export abstract class EmbeddingClient {
|
||||
|
||||
async generateEmbeddings(
|
||||
chunks: Chunk[],
|
||||
signal?: AbortSignal
|
||||
options?: EmbeddingCallOptionsInput
|
||||
): Promise<Embedding[]> {
|
||||
const normalizedOptions = normalizeEmbeddingCallOptions(options);
|
||||
const retry = 3;
|
||||
|
||||
let embeddings: Embedding[] = [];
|
||||
@@ -164,7 +198,7 @@ export abstract class EmbeddingClient {
|
||||
try {
|
||||
embeddings = await this.getEmbeddings(
|
||||
chunks.map(c => c.content),
|
||||
signal
|
||||
normalizedOptions
|
||||
);
|
||||
break;
|
||||
} catch (e) {
|
||||
@@ -181,7 +215,7 @@ export abstract class EmbeddingClient {
|
||||
_query: string,
|
||||
embeddings: Chunk[],
|
||||
topK: number,
|
||||
_signal?: AbortSignal
|
||||
_options?: EmbeddingCallOptionsInput
|
||||
): Promise<Chunk[]> {
|
||||
// sort by distance with ascending order
|
||||
return embeddings
|
||||
@@ -189,14 +223,14 @@ export abstract class EmbeddingClient {
|
||||
.slice(0, topK);
|
||||
}
|
||||
|
||||
async getEmbedding(query: string, signal?: AbortSignal) {
|
||||
const embedding = await this.getEmbeddings([query], signal);
|
||||
async getEmbedding(query: string, options?: EmbeddingCallOptionsInput) {
|
||||
const embedding = await this.getEmbeddings([query], options);
|
||||
return embedding?.[0]?.embedding;
|
||||
}
|
||||
|
||||
abstract getEmbeddings(
|
||||
input: string[],
|
||||
signal?: AbortSignal
|
||||
options?: EmbeddingCallOptionsInput
|
||||
): Promise<Embedding[]>;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,9 @@
|
||||
import { CopilotAccessPolicy } from './access';
|
||||
import {
|
||||
ByokEntitlementPolicy,
|
||||
ByokService,
|
||||
WorkspaceByokResolver,
|
||||
} from './byok';
|
||||
import { HistoryAttachmentUrlProjector } from './compat/history-attachment-url-projector';
|
||||
import { CompatHistoryProjector } from './compat/history-projector';
|
||||
import { HistoryPromptPreloadProjector } from './compat/history-prompt-preload-projector';
|
||||
@@ -64,10 +70,13 @@ export const COPILOT_PROVIDER_PROVIDERS = [
|
||||
];
|
||||
|
||||
export const COPILOT_RUNTIME_PROVIDERS = [
|
||||
ByokEntitlementPolicy,
|
||||
ByokService,
|
||||
ChatSessionService,
|
||||
ConversationStore,
|
||||
ConversationInboxService,
|
||||
ConversationPolicy,
|
||||
CopilotAccessPolicy,
|
||||
HistoryAttachmentUrlProjector,
|
||||
CompatHistoryProjector,
|
||||
HistoryPromptPreloadProjector,
|
||||
@@ -114,6 +123,7 @@ export const COPILOT_RESOLVER_PROVIDERS = [
|
||||
CopilotResolver,
|
||||
UserCopilotResolver,
|
||||
CopilotContextRootResolver,
|
||||
WorkspaceByokResolver,
|
||||
];
|
||||
|
||||
export const COPILOT_JOB_PROVIDERS = [CopilotEmbeddingJob, CopilotCronJobs];
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import { CopilotProviderSideError, UserFriendlyError } from '../../../base';
|
||||
import { type LlmBackendConfig } from '../../../native';
|
||||
import type { CopilotTool } from '../tools';
|
||||
import { CopilotProvider } from './provider';
|
||||
import {
|
||||
type CopilotProviderExecution,
|
||||
type ProviderDriverSpec,
|
||||
} from './provider-runtime-contract';
|
||||
import { type CopilotChatTools, CopilotProviderType } from './types';
|
||||
import { CopilotProviderType } from './types';
|
||||
|
||||
export type CloudflareWorkersAIConfig = {
|
||||
apiToken: string;
|
||||
@@ -25,16 +24,6 @@ export class CloudflareWorkersAIProvider extends CopilotProvider<CloudflareWorke
|
||||
const config = this.getConfig(execution);
|
||||
return !!config.apiToken && (!!config.accountId || !!config.baseURL);
|
||||
}
|
||||
override getProviderSpecificTools(
|
||||
toolName: CopilotChatTools,
|
||||
_model: string
|
||||
): [string, CopilotTool?] | undefined {
|
||||
if (toolName === 'docEdit') {
|
||||
return ['doc_edit', undefined];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
private handleError(e: any) {
|
||||
if (e instanceof UserFriendlyError) {
|
||||
return e;
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
|
||||
import { CopilotQuotaExceeded } from '../../../base';
|
||||
import { ServerFeature, ServerService } from '../../../core';
|
||||
import { type CopilotAccessContext, CopilotAccessPolicy } from '../access';
|
||||
import type { RequiredStructuredOutputContract } from '../runtime/contracts';
|
||||
import { getProviderRuntimeHost } from '../runtime/provider-runtime-context';
|
||||
import type { CopilotProvider } from './provider';
|
||||
import {
|
||||
buildProviderRegistry,
|
||||
type CopilotProviderRegistry,
|
||||
type NormalizedCopilotProviderProfile,
|
||||
resolveModel,
|
||||
stripProviderPrefix,
|
||||
@@ -57,11 +61,18 @@ type RoutePreparationResult = Partial<
|
||||
>
|
||||
>;
|
||||
|
||||
type EffectiveProviderRegistry = {
|
||||
byokRegistry: CopilotProviderRegistry;
|
||||
quotaBackedRegistry: CopilotProviderRegistry;
|
||||
quotaBackedRoutesAvailable: boolean;
|
||||
};
|
||||
|
||||
@Injectable()
|
||||
export class CopilotProviderFactory {
|
||||
constructor(
|
||||
private readonly server: ServerService,
|
||||
private readonly registries: CopilotProviderRegistryService
|
||||
private readonly registries: CopilotProviderRegistryService,
|
||||
private readonly access: CopilotAccessPolicy
|
||||
) {}
|
||||
|
||||
private readonly logger = new Logger(CopilotProviderFactory.name);
|
||||
@@ -73,20 +84,84 @@ export class CopilotProviderFactory {
|
||||
return this.registries.getRegistry();
|
||||
}
|
||||
|
||||
private getPreferredProviderIds(type?: CopilotProviderType) {
|
||||
private getProviderByProfile(
|
||||
providerId: string,
|
||||
profile: NormalizedCopilotProviderProfile
|
||||
) {
|
||||
return (
|
||||
this.#providers.get(providerId) ??
|
||||
Array.from(this.#providerIdsByType.get(profile.type) ?? [])
|
||||
.map(id => this.#providers.get(id))
|
||||
.find((provider): provider is CopilotProvider => !!provider)
|
||||
);
|
||||
}
|
||||
|
||||
private providerAvailable(
|
||||
providerId: string,
|
||||
profile: NormalizedCopilotProviderProfile
|
||||
) {
|
||||
return !!this.getProviderByProfile(providerId, profile);
|
||||
}
|
||||
|
||||
private getAvailableProviderIds(registry: CopilotProviderRegistry) {
|
||||
return Array.from(registry.profiles.entries())
|
||||
.filter(([providerId, profile]) =>
|
||||
this.providerAvailable(providerId, profile)
|
||||
)
|
||||
.map(([providerId]) => providerId);
|
||||
}
|
||||
|
||||
private getPreferredProviderIds(
|
||||
registry: CopilotProviderRegistry,
|
||||
type?: CopilotProviderType
|
||||
) {
|
||||
if (!type) return undefined;
|
||||
return this.#providerIdsByType.get(type);
|
||||
return registry.byType.get(type)?.filter(providerId => {
|
||||
const profile = registry.profiles.get(providerId);
|
||||
return profile ? this.providerAvailable(providerId, profile) : false;
|
||||
});
|
||||
}
|
||||
|
||||
private normalizeCond(
|
||||
registry: CopilotProviderRegistry,
|
||||
providerId: string,
|
||||
cond: ModelFullConditions
|
||||
): ModelFullConditions {
|
||||
const registry = this.getRegistry();
|
||||
const modelId = stripProviderPrefix(registry, providerId, cond.modelId);
|
||||
return { ...cond, modelId };
|
||||
}
|
||||
|
||||
private async getEffectiveRegistry(
|
||||
context: CopilotAccessContext = {}
|
||||
): Promise<EffectiveProviderRegistry> {
|
||||
const quotaBackedRegistry = this.getRegistry();
|
||||
const routeAccess = await this.access.resolveRouteAccess(context);
|
||||
|
||||
return {
|
||||
byokRegistry: buildProviderRegistry({
|
||||
profiles: routeAccess.byokProfiles,
|
||||
defaults: {},
|
||||
}),
|
||||
quotaBackedRegistry,
|
||||
quotaBackedRoutesAvailable: routeAccess.quotaBackedRoutesAvailable,
|
||||
};
|
||||
}
|
||||
|
||||
private getRequestContext(
|
||||
options?:
|
||||
| CopilotChatOptions
|
||||
| CopilotStructuredOptions
|
||||
| CopilotImageOptions
|
||||
): CopilotAccessContext {
|
||||
return {
|
||||
userId: options?.user,
|
||||
workspaceId: options?.workspace,
|
||||
byokLeaseId: options?.byokLeaseId,
|
||||
featureKind: options?.featureKind,
|
||||
quotaBackedRoutesAllowed: options?.quotaBackedRoutesAllowed,
|
||||
};
|
||||
}
|
||||
|
||||
private filterPreparedRoutes(routes: Array<ResolvedCopilotProvider | null>) {
|
||||
return routes.filter(
|
||||
(route): route is ResolvedCopilotProvider => route !== null
|
||||
@@ -113,36 +188,89 @@ export class CopilotProviderFactory {
|
||||
cond: ModelFullConditions,
|
||||
filter: {
|
||||
prefer?: CopilotProviderType;
|
||||
} = {}
|
||||
} = {},
|
||||
context: CopilotAccessContext = {}
|
||||
): Promise<ResolvedCopilotProvider | null> {
|
||||
return (await this.resolveRoutes(cond, filter))[0] ?? null;
|
||||
return (await this.resolveRoutes(cond, filter, context))[0] ?? null;
|
||||
}
|
||||
|
||||
async resolveRoutes(
|
||||
cond: ModelFullConditions,
|
||||
filter: {
|
||||
prefer?: CopilotProviderType;
|
||||
} = {}
|
||||
} = {},
|
||||
context: CopilotAccessContext = {}
|
||||
): Promise<ResolvedCopilotProvider[]> {
|
||||
this.logger.debug(
|
||||
`Resolving copilot provider for output type: ${cond.outputType}`
|
||||
);
|
||||
const registry = this.getRegistry();
|
||||
const { byokRegistry, quotaBackedRegistry, quotaBackedRoutesAvailable } =
|
||||
await this.getEffectiveRegistry(context);
|
||||
const byokRoutes = await this.resolveRoutesFromRegistry(
|
||||
byokRegistry,
|
||||
cond,
|
||||
filter
|
||||
);
|
||||
const resolved = byokRoutes.length
|
||||
? byokRoutes
|
||||
: quotaBackedRoutesAvailable
|
||||
? await this.resolveRoutesFromRegistry(
|
||||
quotaBackedRegistry,
|
||||
cond,
|
||||
filter
|
||||
)
|
||||
: [];
|
||||
for (const route of resolved) {
|
||||
this.logger.debug(
|
||||
`Copilot provider candidate found: ${route.provider.type} (${route.providerId})`
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
!resolved.length &&
|
||||
!quotaBackedRoutesAvailable &&
|
||||
context.quotaBackedRoutesAllowed !== false
|
||||
) {
|
||||
const quotaBackedRoutes = await this.resolveRoutesFromRegistry(
|
||||
quotaBackedRegistry,
|
||||
cond,
|
||||
filter
|
||||
);
|
||||
if (quotaBackedRoutes.length) {
|
||||
throw new CopilotQuotaExceeded();
|
||||
}
|
||||
}
|
||||
|
||||
return resolved;
|
||||
}
|
||||
|
||||
private async resolveRoutesFromRegistry(
|
||||
registry: CopilotProviderRegistry,
|
||||
cond: ModelFullConditions,
|
||||
filter: {
|
||||
prefer?: CopilotProviderType;
|
||||
} = {}
|
||||
): Promise<ResolvedCopilotProvider[]> {
|
||||
const route = resolveModel({
|
||||
registry,
|
||||
modelId: cond.modelId,
|
||||
outputType: cond.outputType,
|
||||
availableProviderIds: this.#providers.keys(),
|
||||
preferredProviderIds: this.getPreferredProviderIds(filter.prefer),
|
||||
availableProviderIds: this.getAvailableProviderIds(registry),
|
||||
preferredProviderIds: this.getPreferredProviderIds(
|
||||
registry,
|
||||
filter.prefer
|
||||
),
|
||||
});
|
||||
|
||||
const resolved: ResolvedCopilotProvider[] = [];
|
||||
for (const providerId of route.candidateProviderIds) {
|
||||
const provider = this.#providers.get(providerId);
|
||||
const profile = registry.profiles.get(providerId);
|
||||
const provider = profile
|
||||
? this.getProviderByProfile(providerId, profile)
|
||||
: undefined;
|
||||
if (!provider || !profile) continue;
|
||||
|
||||
const normalizedCond = this.normalizeCond(providerId, cond);
|
||||
const normalizedCond = this.normalizeCond(registry, providerId, cond);
|
||||
if (
|
||||
normalizedCond.modelId &&
|
||||
profile.models?.length &&
|
||||
@@ -155,9 +283,6 @@ export class CopilotProviderFactory {
|
||||
const matched = await provider.match(normalizedCond, execution);
|
||||
if (!matched) continue;
|
||||
|
||||
this.logger.debug(
|
||||
`Copilot provider candidate found: ${provider.type} (${providerId})`
|
||||
);
|
||||
resolved.push({
|
||||
providerId,
|
||||
provider,
|
||||
@@ -181,7 +306,11 @@ export class CopilotProviderFactory {
|
||||
prefer?: CopilotProviderType;
|
||||
} = {}
|
||||
): Promise<ResolvedCopilotProvider[]> {
|
||||
const routes = await this.resolveRoutes(cond, filter);
|
||||
const routes = await this.resolveRoutes(
|
||||
cond,
|
||||
filter,
|
||||
this.getRequestContext(options)
|
||||
);
|
||||
return await this.prepareResolvedRoutes(routes, async route => {
|
||||
const prepared = await getProviderRuntimeHost(
|
||||
route.provider
|
||||
@@ -213,7 +342,11 @@ export class CopilotProviderFactory {
|
||||
} = {},
|
||||
responseContract?: RequiredStructuredOutputContract
|
||||
): Promise<ResolvedCopilotProvider[]> {
|
||||
const routes = await this.resolveRoutes(cond, filter);
|
||||
const routes = await this.resolveRoutes(
|
||||
cond,
|
||||
filter,
|
||||
this.getRequestContext(options)
|
||||
);
|
||||
return await this.prepareResolvedRoutes(routes, async route => {
|
||||
const preparedStructured =
|
||||
(await getProviderRuntimeHost(route.provider).prepare.structured(
|
||||
@@ -239,10 +372,14 @@ export class CopilotProviderFactory {
|
||||
input: string | string[],
|
||||
options: CopilotEmbeddingOptions = {}
|
||||
): Promise<ResolvedCopilotProvider[]> {
|
||||
const routes = await this.resolveRoutes({
|
||||
modelId,
|
||||
outputType: ModelOutputType.Embedding,
|
||||
});
|
||||
const routes = await this.resolveRoutes(
|
||||
{ modelId, outputType: ModelOutputType.Embedding },
|
||||
{},
|
||||
{
|
||||
...this.getRequestContext(options),
|
||||
featureKind: options?.featureKind ?? 'embedding',
|
||||
}
|
||||
);
|
||||
return await this.prepareResolvedRoutes(routes, async route => {
|
||||
const preparedEmbedding =
|
||||
(await getProviderRuntimeHost(route.provider).prepare.embedding(
|
||||
@@ -267,10 +404,14 @@ export class CopilotProviderFactory {
|
||||
request: CopilotRerankRequest,
|
||||
options: CopilotChatOptions = {}
|
||||
): Promise<ResolvedCopilotProvider[]> {
|
||||
const routes = await this.resolveRoutes({
|
||||
const routes = await this.resolveRoutes(
|
||||
{
|
||||
modelId,
|
||||
outputType: ModelOutputType.Rerank,
|
||||
});
|
||||
},
|
||||
{},
|
||||
{ ...this.getRequestContext(options), featureKind: 'rerank' }
|
||||
);
|
||||
return await this.prepareResolvedRoutes(routes, async route => {
|
||||
const preparedRerank =
|
||||
(await getProviderRuntimeHost(route.provider).prepare.rerank(
|
||||
@@ -298,7 +439,10 @@ export class CopilotProviderFactory {
|
||||
prefer?: CopilotProviderType;
|
||||
} = {}
|
||||
): Promise<ResolvedCopilotProvider[]> {
|
||||
const routes = await this.resolveRoutes(cond, filter);
|
||||
const routes = await this.resolveRoutes(cond, filter, {
|
||||
...this.getRequestContext(options),
|
||||
featureKind: options?.featureKind ?? 'image',
|
||||
});
|
||||
return await this.prepareResolvedRoutes(routes, async route => {
|
||||
const preparedImage =
|
||||
(await getProviderRuntimeHost(route.provider).prepare.image(
|
||||
|
||||
@@ -8,7 +8,6 @@ export { FalProvider } from './fal';
|
||||
export { GeminiGenerativeProvider, GeminiVertexProvider } from './gemini';
|
||||
export { CopilotProviderLifecycleService } from './lifecycle-service';
|
||||
export { OpenAIProvider } from './openai';
|
||||
export { PerplexityProvider } from './perplexity';
|
||||
export type { CopilotProvider } from './provider';
|
||||
export { CopilotProviders } from './provider-tokens';
|
||||
export { CopilotProviderRegistryService } from './registry-service';
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
import { CopilotProviderSideError, UserFriendlyError } from '../../../base';
|
||||
import { type LlmBackendConfig } from '../../../native';
|
||||
import { CopilotProvider } from './provider';
|
||||
import {
|
||||
type CopilotProviderExecution,
|
||||
type ProviderDriverSpec,
|
||||
} from './provider-runtime-contract';
|
||||
import { CopilotProviderType, ModelOutputType } from './types';
|
||||
|
||||
export const DEFAULT_DIMENSIONS = 256;
|
||||
|
||||
export type MorphConfig = {
|
||||
apiKey?: string;
|
||||
};
|
||||
|
||||
export class MorphProvider extends CopilotProvider<MorphConfig> {
|
||||
readonly type = CopilotProviderType.Morph;
|
||||
|
||||
protected resolveModelBackendKind() {
|
||||
return 'morph' as const;
|
||||
}
|
||||
|
||||
override configured(execution?: CopilotProviderExecution): boolean {
|
||||
return !!this.getConfig(execution).apiKey;
|
||||
}
|
||||
|
||||
private handleError(e: any) {
|
||||
if (e instanceof UserFriendlyError) {
|
||||
return e;
|
||||
}
|
||||
return new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: 'unexpected_response',
|
||||
message: e?.message || 'Unexpected morph response',
|
||||
});
|
||||
}
|
||||
|
||||
private createNativeConfig(
|
||||
execution?: CopilotProviderExecution
|
||||
): LlmBackendConfig {
|
||||
return {
|
||||
base_url: 'https://api.morphllm.com',
|
||||
auth_token: this.getConfig(execution).apiKey ?? '',
|
||||
};
|
||||
}
|
||||
|
||||
override getDriverSpec(): ProviderDriverSpec {
|
||||
return {
|
||||
createBackendConfig: execution => this.createNativeConfig(execution),
|
||||
mapError: error => this.handleError(error),
|
||||
chat: {
|
||||
resolveOutputType: kind =>
|
||||
kind === 'streamObject' ? null : ModelOutputType.Text,
|
||||
},
|
||||
structured: false,
|
||||
embedding: false,
|
||||
rerank: false,
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -14,7 +14,6 @@ import {
|
||||
AttachmentAdmissionHost,
|
||||
} from '../runtime/hosts/attachment-admission';
|
||||
import { AttachmentMaterializer } from '../runtime/hosts/attachment-materializer';
|
||||
import type { CopilotTool } from '../tools';
|
||||
import { CopilotProvider } from './provider';
|
||||
import { hasProviderModelBehaviorFlag } from './provider-model-runtime';
|
||||
import type {
|
||||
@@ -22,7 +21,6 @@ import type {
|
||||
ProviderDriverSpec,
|
||||
} from './provider-runtime-contract';
|
||||
import {
|
||||
CopilotChatTools,
|
||||
CopilotProviderType,
|
||||
type PromptAttachment,
|
||||
type PromptMessage,
|
||||
@@ -64,16 +62,6 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
|
||||
});
|
||||
}
|
||||
|
||||
override getProviderSpecificTools(
|
||||
toolName: CopilotChatTools,
|
||||
_model: string
|
||||
): [string, CopilotTool?] | undefined {
|
||||
if (toolName === 'docEdit') {
|
||||
return ['doc_edit', undefined];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
protected createNativeConfig(
|
||||
execution?: CopilotProviderExecution
|
||||
): LlmBackendConfig {
|
||||
|
||||
@@ -1,75 +0,0 @@
|
||||
import { CopilotProviderSideError } from '../../../base';
|
||||
import { type LlmBackendConfig } from '../../../native';
|
||||
import { CopilotProvider } from './provider';
|
||||
import { hasProviderModelBehaviorFlag } from './provider-model-runtime';
|
||||
import {
|
||||
type CopilotProviderExecution,
|
||||
type ProviderDriverSpec,
|
||||
} from './provider-runtime-contract';
|
||||
import { CopilotProviderType, ModelOutputType } from './types';
|
||||
|
||||
export type PerplexityConfig = {
|
||||
apiKey: string;
|
||||
endpoint?: string;
|
||||
};
|
||||
|
||||
export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
|
||||
readonly type = CopilotProviderType.Perplexity;
|
||||
|
||||
protected resolveModelBackendKind() {
|
||||
return 'perplexity' as const;
|
||||
}
|
||||
|
||||
override configured(execution?: CopilotProviderExecution): boolean {
|
||||
return !!this.getConfig(execution).apiKey;
|
||||
}
|
||||
|
||||
override getDriverSpec(): ProviderDriverSpec {
|
||||
return {
|
||||
createBackendConfig: execution => this.createNativeConfig(execution),
|
||||
mapError: error => this.handleError(error),
|
||||
chat: {
|
||||
resolveOutputType: kind =>
|
||||
kind === 'streamObject' ? null : ModelOutputType.Text,
|
||||
withAttachment: false,
|
||||
resolveRequestOptions: async context => ({
|
||||
withAttachment: !hasProviderModelBehaviorFlag(
|
||||
context.model,
|
||||
'no_attachments'
|
||||
),
|
||||
include: hasProviderModelBehaviorFlag(
|
||||
context.model,
|
||||
'citations_include'
|
||||
)
|
||||
? ['citations']
|
||||
: undefined,
|
||||
}),
|
||||
},
|
||||
structured: false,
|
||||
embedding: false,
|
||||
rerank: false,
|
||||
};
|
||||
}
|
||||
|
||||
private createNativeConfig(
|
||||
execution?: CopilotProviderExecution
|
||||
): LlmBackendConfig {
|
||||
const config = this.getConfig(execution);
|
||||
const baseUrl = config.endpoint || 'https://api.perplexity.ai';
|
||||
return {
|
||||
base_url: baseUrl.replace(/\/v1\/?$/, ''),
|
||||
auth_token: config.apiKey,
|
||||
};
|
||||
}
|
||||
|
||||
private handleError(e: any) {
|
||||
if (e instanceof CopilotProviderSideError) {
|
||||
return e;
|
||||
}
|
||||
return new CopilotProviderSideError({
|
||||
provider: this.type,
|
||||
kind: 'unexpected_response',
|
||||
message: e?.message || 'Unexpected perplexity response',
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -21,18 +21,6 @@ const DEFAULT_MIDDLEWARE_BY_TYPE: Record<
|
||||
[CopilotProviderType.AnthropicVertex]: {
|
||||
node: { text: DEFAULT_NODE_TEXT_MIDDLEWARE },
|
||||
},
|
||||
[CopilotProviderType.Morph]: {
|
||||
rust: {
|
||||
request: ['clamp_max_tokens'],
|
||||
},
|
||||
node: { text: DEFAULT_NODE_TEXT_MIDDLEWARE },
|
||||
},
|
||||
[CopilotProviderType.Perplexity]: {
|
||||
rust: {
|
||||
request: ['clamp_max_tokens'],
|
||||
},
|
||||
node: { text: DEFAULT_NODE_TEXT_MIDDLEWARE },
|
||||
},
|
||||
[CopilotProviderType.Gemini]: {
|
||||
node: { text: DEFAULT_NODE_TEXT_MIDDLEWARE },
|
||||
},
|
||||
|
||||
@@ -15,10 +15,8 @@ const LEGACY_PROVIDER_ORDER: CopilotProviderType[] = [
|
||||
CopilotProviderType.FAL,
|
||||
CopilotProviderType.Gemini,
|
||||
CopilotProviderType.GeminiVertex,
|
||||
CopilotProviderType.Perplexity,
|
||||
CopilotProviderType.Anthropic,
|
||||
CopilotProviderType.AnthropicVertex,
|
||||
CopilotProviderType.Morph,
|
||||
];
|
||||
|
||||
const LEGACY_PROVIDER_PRIORITY = LEGACY_PROVIDER_ORDER.reduce(
|
||||
|
||||
@@ -5,9 +5,7 @@ import {
|
||||
import { CloudflareWorkersAIProvider } from './cloudflare';
|
||||
import { FalProvider } from './fal';
|
||||
import { GeminiGenerativeProvider, GeminiVertexProvider } from './gemini';
|
||||
import { MorphProvider } from './morph';
|
||||
import { OpenAIProvider } from './openai';
|
||||
import { PerplexityProvider } from './perplexity';
|
||||
|
||||
export const CopilotProviders = [
|
||||
OpenAIProvider,
|
||||
@@ -15,8 +13,6 @@ export const CopilotProviders = [
|
||||
FalProvider,
|
||||
GeminiGenerativeProvider,
|
||||
GeminiVertexProvider,
|
||||
PerplexityProvider,
|
||||
AnthropicOfficialProvider,
|
||||
AnthropicVertexProvider,
|
||||
MorphProvider,
|
||||
];
|
||||
|
||||
@@ -30,8 +30,6 @@ export enum CopilotProviderType {
|
||||
Gemini = 'gemini',
|
||||
GeminiVertex = 'geminiVertex',
|
||||
OpenAI = 'openai',
|
||||
Perplexity = 'perplexity',
|
||||
Morph = 'morph',
|
||||
}
|
||||
|
||||
export const CopilotProviderSchema = z.object({
|
||||
@@ -80,8 +78,6 @@ export const PromptToolsSchema = z
|
||||
'blobRead',
|
||||
'codeArtifact',
|
||||
'conversationSummary',
|
||||
// work with morph
|
||||
'docEdit',
|
||||
// work with indexer
|
||||
'docRead',
|
||||
'docCreate',
|
||||
@@ -268,6 +264,22 @@ const CopilotProviderOptionsSchema = z.object({
|
||||
user: z.string().optional(),
|
||||
session: z.string().optional(),
|
||||
workspace: z.string().optional(),
|
||||
byokLeaseId: z.string().optional(),
|
||||
billingUnitId: z.string().optional(),
|
||||
taskId: z.string().optional(),
|
||||
actionId: z.string().optional(),
|
||||
quotaBackedRoutesAllowed: z.boolean().optional(),
|
||||
featureKind: z
|
||||
.enum([
|
||||
'chat',
|
||||
'action',
|
||||
'image',
|
||||
'embedding',
|
||||
'workspace_indexing',
|
||||
'rerank',
|
||||
'transcript',
|
||||
])
|
||||
.optional(),
|
||||
});
|
||||
|
||||
export const CopilotChatOptionsSchema = CopilotProviderOptionsSchema.merge(
|
||||
|
||||
@@ -164,11 +164,6 @@ export function toError(error: unknown): Error {
|
||||
}
|
||||
}
|
||||
|
||||
type DocEditFootnote = {
|
||||
intent: string;
|
||||
result: string;
|
||||
};
|
||||
|
||||
function asRecord(value: unknown): Record<string, unknown> | null {
|
||||
if (value && typeof value === 'object' && !Array.isArray(value)) {
|
||||
return value as Record<string, unknown>;
|
||||
@@ -184,8 +179,6 @@ export class TextStreamParser {
|
||||
|
||||
private prefix: string | null = this.CALLOUT_PREFIX;
|
||||
|
||||
private readonly docEditFootnotes: DocEditFootnote[] = [];
|
||||
|
||||
public parse(chunk: CopilotTextStreamPart) {
|
||||
let result = '';
|
||||
switch (chunk.type) {
|
||||
@@ -233,13 +226,6 @@ export class TextStreamParser {
|
||||
result += `\nWriting document "${chunk.input.title}"\n`;
|
||||
break;
|
||||
}
|
||||
case 'doc_edit': {
|
||||
this.docEditFootnotes.push({
|
||||
intent: String(chunk.input.instructions ?? ''),
|
||||
result: '',
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
result = this.markAsCallout(result);
|
||||
break;
|
||||
@@ -250,22 +236,6 @@ export class TextStreamParser {
|
||||
);
|
||||
result = this.addPrefix(result);
|
||||
switch (chunk.toolName) {
|
||||
case 'doc_edit': {
|
||||
const output = asRecord(chunk.output);
|
||||
const array = output?.result;
|
||||
if (Array.isArray(array)) {
|
||||
result += array
|
||||
.map(item => {
|
||||
return `\n${String(asRecord(item)?.changedContent ?? '')}\n`;
|
||||
})
|
||||
.join('');
|
||||
this.docEditFootnotes[this.docEditFootnotes.length - 1].result =
|
||||
result;
|
||||
} else {
|
||||
this.docEditFootnotes.pop();
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'doc_semantic_search': {
|
||||
const output = chunk.output;
|
||||
if (Array.isArray(output)) {
|
||||
@@ -319,10 +289,7 @@ export class TextStreamParser {
|
||||
}
|
||||
|
||||
public end() {
|
||||
const footnotes = this.docEditFootnotes.map((footnote, index) => {
|
||||
return `[^edit${index + 1}]: ${JSON.stringify({ type: 'doc-edit', ...footnote })}`;
|
||||
});
|
||||
return footnotes.join('\n');
|
||||
return '';
|
||||
}
|
||||
|
||||
private addPrefix(text: string) {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { BadRequestException, NotFoundException } from '@nestjs/common';
|
||||
import { NotFoundException } from '@nestjs/common';
|
||||
import {
|
||||
Args,
|
||||
Field,
|
||||
@@ -7,7 +7,6 @@ import {
|
||||
Mutation,
|
||||
ObjectType,
|
||||
Parent,
|
||||
Query,
|
||||
registerEnumType,
|
||||
ResolveField,
|
||||
Resolver,
|
||||
@@ -19,7 +18,6 @@ import {
|
||||
CallMetric,
|
||||
CopilotDocNotFound,
|
||||
CopilotFailedToCreateMessage,
|
||||
CopilotProviderSideError,
|
||||
CopilotSessionNotFound,
|
||||
type FileUpload,
|
||||
paginate,
|
||||
@@ -28,10 +26,8 @@ import {
|
||||
RequestMutex,
|
||||
Throttle,
|
||||
TooManyRequest,
|
||||
UserFriendlyError,
|
||||
} from '../../base';
|
||||
import { CurrentUser } from '../../core/auth';
|
||||
import { DocReader } from '../../core/doc';
|
||||
import { AccessController, DocAction } from '../../core/permission';
|
||||
import { UserType } from '../../core/user';
|
||||
import type { ListSessionOptions, UpdateChatSession } from '../../models';
|
||||
@@ -40,7 +36,6 @@ import { ConversationInboxService } from './conversation/inbox';
|
||||
import { PromptService } from './prompt/service';
|
||||
import { CopilotProviderFactory } from './providers/factory';
|
||||
import { ModelOutputType, type StreamObject } from './providers/types';
|
||||
import { CapabilityRuntime } from './runtime/capability-runtime';
|
||||
import { ChatSessionService } from './session';
|
||||
import { type ChatHistory, type ChatMessage, SubmittedMessage } from './types';
|
||||
|
||||
@@ -376,9 +371,7 @@ export class CopilotResolver {
|
||||
private readonly chatSession: ChatSessionService,
|
||||
private readonly historyProjector: CompatHistoryProjector,
|
||||
private readonly inbox: ConversationInboxService,
|
||||
private readonly docReader: DocReader,
|
||||
private readonly providerFactory: CopilotProviderFactory,
|
||||
private readonly runtime: CapabilityRuntime
|
||||
private readonly providerFactory: CopilotProviderFactory
|
||||
) {}
|
||||
|
||||
@ResolveField(() => CopilotQuotaType, {
|
||||
@@ -641,8 +634,6 @@ export class CopilotResolver {
|
||||
throw new TooManyRequest('Server is busy');
|
||||
}
|
||||
|
||||
await this.chatSession.checkQuota(user.id);
|
||||
|
||||
return await this.chatSession.create({
|
||||
...options,
|
||||
pinned: options.pinned ?? false,
|
||||
@@ -724,7 +715,6 @@ export class CopilotResolver {
|
||||
throw new TooManyRequest('Server is busy');
|
||||
}
|
||||
|
||||
await this.chatSession.checkQuota(user.id);
|
||||
return await this.chatSession.update({
|
||||
...options,
|
||||
userId: user.id,
|
||||
@@ -752,8 +742,6 @@ export class CopilotResolver {
|
||||
throw new CopilotDocNotFound({ docId: options.docId });
|
||||
}
|
||||
|
||||
await this.chatSession.checkQuota(user.id);
|
||||
|
||||
return await this.chatSession.fork({
|
||||
...options,
|
||||
userId: user.id,
|
||||
@@ -819,96 +807,6 @@ export class CopilotResolver {
|
||||
}
|
||||
}
|
||||
|
||||
@Query(() => String, {
|
||||
description:
|
||||
'Apply updates to a doc using LLM and return the merged markdown.',
|
||||
deprecationReason: 'use Mutation.applyDocUpdates',
|
||||
})
|
||||
async applyDocUpdates(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args({ name: 'workspaceId', type: () => String })
|
||||
workspaceId: string,
|
||||
@Args({ name: 'docId', type: () => String })
|
||||
docId: string,
|
||||
@Args({ name: 'op', type: () => String })
|
||||
op: string,
|
||||
@Args({ name: 'updates', type: () => String })
|
||||
updates: string
|
||||
): Promise<string> {
|
||||
return this.applyDocUpdatesInternal(user, workspaceId, docId, op, updates);
|
||||
}
|
||||
|
||||
@Mutation(() => String, {
|
||||
description:
|
||||
'Apply updates to a doc using LLM and return the merged markdown.',
|
||||
name: 'applyDocUpdates',
|
||||
})
|
||||
async applyDocUpdatesMutation(
|
||||
@CurrentUser() user: CurrentUser,
|
||||
@Args({ name: 'workspaceId', type: () => String })
|
||||
workspaceId: string,
|
||||
@Args({ name: 'docId', type: () => String })
|
||||
docId: string,
|
||||
@Args({ name: 'op', type: () => String })
|
||||
op: string,
|
||||
@Args({ name: 'updates', type: () => String })
|
||||
updates: string
|
||||
): Promise<string> {
|
||||
return this.applyDocUpdatesInternal(user, workspaceId, docId, op, updates);
|
||||
}
|
||||
|
||||
private async applyDocUpdatesInternal(
|
||||
user: CurrentUser,
|
||||
workspaceId: string,
|
||||
docId: string,
|
||||
op: string,
|
||||
updates: string
|
||||
): Promise<string> {
|
||||
await this.assertPermission(user, { workspaceId, docId });
|
||||
|
||||
const docContent = await this.docReader.getDocMarkdown(
|
||||
workspaceId,
|
||||
docId,
|
||||
true
|
||||
);
|
||||
if (!docContent || !docContent.markdown) {
|
||||
throw new NotFoundException('Doc not found or empty');
|
||||
}
|
||||
|
||||
const markdown = docContent.markdown.trim();
|
||||
|
||||
const resolved = await this.providerFactory.resolveProvider({
|
||||
modelId: 'morph-v3-large',
|
||||
outputType: ModelOutputType.Text,
|
||||
});
|
||||
if (!resolved) {
|
||||
throw new BadRequestException('No LLM provider available');
|
||||
}
|
||||
|
||||
try {
|
||||
return await this.runtime.text(
|
||||
{ modelId: 'morph-v3-large' },
|
||||
[
|
||||
{
|
||||
role: 'user',
|
||||
content: `<instruction>${op}</instruction>\n<code>${markdown}</code>\n<update>${updates}</update>`,
|
||||
},
|
||||
],
|
||||
{ reasoning: false }
|
||||
);
|
||||
} catch (e: any) {
|
||||
if (e instanceof UserFriendlyError) {
|
||||
throw e;
|
||||
} else {
|
||||
throw new CopilotProviderSideError({
|
||||
provider: resolved.provider.type,
|
||||
kind: 'unexpected_response',
|
||||
message: e?.message || 'Unexpected apply response',
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private transformToSessionType(session: Omit<ChatHistory, 'messages'>) {
|
||||
return { id: session.sessionId, ...session };
|
||||
}
|
||||
|
||||
@@ -269,19 +269,20 @@ export class ActionRuntimeBridge {
|
||||
attempt,
|
||||
});
|
||||
|
||||
const inputWithBillingUnit = this.withBillingUnit(input, run.id);
|
||||
let finalEvent: NativeActionEvent | undefined;
|
||||
const attachments: unknown[] = [];
|
||||
try {
|
||||
const nativeInput = await this.prepareNativeInput({
|
||||
...input,
|
||||
...inputWithBillingUnit,
|
||||
});
|
||||
for await (const event of this.runNativeStream(
|
||||
{
|
||||
...nativeInput,
|
||||
recipeId: input.actionId,
|
||||
recipeVersion: input.actionVersion,
|
||||
recipeId: inputWithBillingUnit.actionId,
|
||||
recipeVersion: inputWithBillingUnit.actionVersion,
|
||||
},
|
||||
input.signal
|
||||
inputWithBillingUnit.signal
|
||||
)) {
|
||||
finalEvent = event;
|
||||
let projectedEvent = event;
|
||||
@@ -343,4 +344,40 @@ export class ActionRuntimeBridge {
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
private withBillingUnit(
|
||||
input: ActionRuntimeBridgeInput,
|
||||
billingUnitId: string
|
||||
): ActionRuntimeBridgeInput {
|
||||
return {
|
||||
...input,
|
||||
prepareStructuredRoutes: input.prepareStructuredRoutes
|
||||
? {
|
||||
...input.prepareStructuredRoutes,
|
||||
options: {
|
||||
...input.prepareStructuredRoutes.options,
|
||||
actionId:
|
||||
input.prepareStructuredRoutes.options?.actionId ??
|
||||
input.actionId,
|
||||
billingUnitId:
|
||||
input.prepareStructuredRoutes.options?.billingUnitId ??
|
||||
billingUnitId,
|
||||
},
|
||||
}
|
||||
: undefined,
|
||||
prepareImageRoutes: input.prepareImageRoutes
|
||||
? {
|
||||
...input.prepareImageRoutes,
|
||||
options: {
|
||||
...input.prepareImageRoutes.options,
|
||||
actionId:
|
||||
input.prepareImageRoutes.options?.actionId ?? input.actionId,
|
||||
billingUnitId:
|
||||
input.prepareImageRoutes.options?.billingUnitId ??
|
||||
billingUnitId,
|
||||
},
|
||||
}
|
||||
: undefined,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -411,6 +411,7 @@ function stripHostOnlyOptions<TOptions extends object | undefined>(
|
||||
user: _user,
|
||||
session: _session,
|
||||
workspace: _workspace,
|
||||
quotaBackedRoutesAllowed: _quotaBackedRoutesAllowed,
|
||||
...serializable
|
||||
} = options as Record<string, unknown>;
|
||||
|
||||
|
||||
@@ -90,6 +90,8 @@ export class ActionStreamHost {
|
||||
prepared.session,
|
||||
params,
|
||||
userId,
|
||||
parsedQuery.byokLeaseId,
|
||||
prepared.quotaBackedRoutesAllowed,
|
||||
signal
|
||||
);
|
||||
const runStream = this.bridge.runStream({
|
||||
@@ -130,6 +132,9 @@ export class ActionStreamHost {
|
||||
user: userId,
|
||||
workspace: prepared.session.config.workspaceId,
|
||||
session: sessionId,
|
||||
byokLeaseId: parsedQuery.byokLeaseId,
|
||||
quotaBackedRoutesAllowed: prepared.quotaBackedRoutesAllowed,
|
||||
featureKind: 'action',
|
||||
},
|
||||
},
|
||||
prepareImageRoutes: imageRoutes
|
||||
@@ -177,6 +182,8 @@ export class ActionStreamHost {
|
||||
session: ChatSession,
|
||||
params: Record<string, unknown>,
|
||||
userId: string,
|
||||
byokLeaseId?: string,
|
||||
quotaBackedRoutesAllowed?: boolean,
|
||||
signal?: AbortSignal
|
||||
): Promise<ImageActionRoutePreparation | undefined> {
|
||||
if (!isImageAction(actionId)) {
|
||||
@@ -201,6 +208,9 @@ export class ActionStreamHost {
|
||||
user: userId,
|
||||
workspace: session.config.workspaceId,
|
||||
session: session.config.sessionId,
|
||||
byokLeaseId,
|
||||
quotaBackedRoutesAllowed,
|
||||
featureKind: 'image',
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
@@ -18,6 +18,10 @@ export type ChatSelectionOptions = {
|
||||
reasoning?: boolean;
|
||||
webSearch?: boolean;
|
||||
toolsConfig?: ToolsConfig;
|
||||
byokLeaseId?: string;
|
||||
billingUnitId?: string;
|
||||
featureKind?: 'chat' | 'action' | 'image';
|
||||
quotaBackedRoutesAllowed?: boolean;
|
||||
};
|
||||
|
||||
type ResolvePolicyModelInput = ResolveModelInput & {
|
||||
@@ -97,6 +101,10 @@ export class CapabilityPolicyHost {
|
||||
user: session.config.userId,
|
||||
session: session.config.sessionId,
|
||||
workspace: session.config.workspaceId,
|
||||
byokLeaseId: options.byokLeaseId,
|
||||
billingUnitId: options.billingUnitId,
|
||||
featureKind: options.featureKind ?? 'chat',
|
||||
quotaBackedRoutesAllowed: options.quotaBackedRoutesAllowed,
|
||||
reasoning: options.reasoning,
|
||||
webSearch: options.webSearch,
|
||||
tools,
|
||||
|
||||
@@ -5,6 +5,7 @@ import {
|
||||
CopilotSessionNotFound,
|
||||
Mutex,
|
||||
} from '../../../../base';
|
||||
import { CopilotAccessPolicy } from '../../access';
|
||||
import { CompatSubmissionStore } from '../../compat/submission-store';
|
||||
import {
|
||||
canonicalizeTurnTrace,
|
||||
@@ -20,6 +21,12 @@ export type PreparedConversationTurn = {
|
||||
params: Record<string, string>;
|
||||
session: ChatSession;
|
||||
latestTurn?: Turn;
|
||||
quotaBackedRoutesAllowed?: boolean;
|
||||
};
|
||||
|
||||
type AppendedSessionMessage = {
|
||||
turn?: Turn;
|
||||
quotaBackedRoutesAllowed?: boolean;
|
||||
};
|
||||
|
||||
@Injectable()
|
||||
@@ -27,7 +34,8 @@ export class ConversationHost {
|
||||
constructor(
|
||||
private readonly sessions: ChatSessionService,
|
||||
private readonly submissions: CompatSubmissionStore,
|
||||
private readonly mutex: Mutex
|
||||
private readonly mutex: Mutex,
|
||||
private readonly access: CopilotAccessPolicy
|
||||
) {}
|
||||
|
||||
private async loadAcceptedTurn(
|
||||
@@ -101,12 +109,32 @@ export class ConversationHost {
|
||||
session: ChatSession,
|
||||
sessionId: string,
|
||||
messageId?: string,
|
||||
retry = false
|
||||
): Promise<Turn | undefined> {
|
||||
retry = false,
|
||||
byokLeaseId?: string
|
||||
): Promise<AppendedSessionMessage> {
|
||||
const resolveChatRouteAccess = () =>
|
||||
this.access.resolveTurnRouteAccess({
|
||||
userId,
|
||||
workspaceId: session.config.workspaceId,
|
||||
byokLeaseId,
|
||||
featureKind: 'chat',
|
||||
});
|
||||
|
||||
if (!messageId) {
|
||||
await this.sessions.revertLatestMessage(sessionId, false);
|
||||
session.revertLatestMessage(false);
|
||||
return session.latestUserTurn;
|
||||
if (!session.latestUserTurn) {
|
||||
const routeAccess = await resolveChatRouteAccess();
|
||||
return {
|
||||
turn: session.latestUserTurn,
|
||||
quotaBackedRoutesAllowed: routeAccess.quotaBackedRoutesAllowed,
|
||||
};
|
||||
}
|
||||
const routeAccess = await resolveChatRouteAccess();
|
||||
return {
|
||||
turn: session.latestUserTurn,
|
||||
quotaBackedRoutesAllowed: routeAccess.quotaBackedRoutesAllowed,
|
||||
};
|
||||
}
|
||||
|
||||
const acceptedTurn = await this.loadAcceptedTurn(
|
||||
@@ -116,7 +144,7 @@ export class ConversationHost {
|
||||
retry
|
||||
);
|
||||
if (acceptedTurn) {
|
||||
return acceptedTurn;
|
||||
return { turn: acceptedTurn, quotaBackedRoutesAllowed: true };
|
||||
}
|
||||
|
||||
await using lock = await this.mutex.acquire(
|
||||
@@ -132,7 +160,9 @@ export class ConversationHost {
|
||||
messageId,
|
||||
retry
|
||||
);
|
||||
if (acceptedAfterLock) return acceptedAfterLock;
|
||||
if (acceptedAfterLock) {
|
||||
return { turn: acceptedAfterLock, quotaBackedRoutesAllowed: true };
|
||||
}
|
||||
|
||||
const durableTurn = await this.loadDurableTurn(
|
||||
session,
|
||||
@@ -140,9 +170,14 @@ export class ConversationHost {
|
||||
messageId,
|
||||
retry
|
||||
);
|
||||
if (durableTurn) return durableTurn;
|
||||
if (durableTurn) {
|
||||
return {
|
||||
turn: durableTurn,
|
||||
quotaBackedRoutesAllowed: true,
|
||||
};
|
||||
}
|
||||
|
||||
await this.sessions.checkQuota(userId);
|
||||
const routeAccess = await resolveChatRouteAccess();
|
||||
|
||||
const submission = await this.submissions.get(messageId);
|
||||
if (!submission || submission.sessionId !== sessionId) {
|
||||
@@ -176,7 +211,10 @@ export class ConversationHost {
|
||||
turnId: turn.id ?? '',
|
||||
});
|
||||
session.pushPersistedTurn(turn);
|
||||
return turn;
|
||||
return {
|
||||
turn,
|
||||
quotaBackedRoutesAllowed: routeAccess.quotaBackedRoutesAllowed,
|
||||
};
|
||||
}
|
||||
|
||||
async prepareTurn(
|
||||
@@ -184,27 +222,30 @@ export class ConversationHost {
|
||||
sessionId: string,
|
||||
query: Record<string, string | string[]>
|
||||
): Promise<PreparedConversationTurn> {
|
||||
const { messageId, retry, params } = ChatQuerySchema.parse(query);
|
||||
const { messageId, retry, params, byokLeaseId } =
|
||||
ChatQuerySchema.parse(query);
|
||||
const session = await this.sessions.get(sessionId);
|
||||
if (!session || session.config.userId !== userId) {
|
||||
throw new CopilotSessionNotFound();
|
||||
}
|
||||
const latestMessage = await this.appendSessionMessage(
|
||||
const appended = await this.appendSessionMessage(
|
||||
userId,
|
||||
session,
|
||||
sessionId,
|
||||
messageId,
|
||||
retry
|
||||
retry,
|
||||
byokLeaseId
|
||||
);
|
||||
const currentUserMessage =
|
||||
session.stashTurns.findLast(turn => turn.role === 'user') ??
|
||||
latestMessage;
|
||||
appended.turn;
|
||||
|
||||
return {
|
||||
messageId,
|
||||
params,
|
||||
session,
|
||||
latestTurn: currentUserMessage,
|
||||
quotaBackedRoutesAllowed: appended.quotaBackedRoutesAllowed,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
|
||||
import { NoCopilotProviderAvailable } from '../../../base';
|
||||
import {
|
||||
@@ -13,6 +13,7 @@ import {
|
||||
llmValidateJsonSchema,
|
||||
parseNativeStructuredOutput,
|
||||
} from '../../../native';
|
||||
import { type ByokFeatureKind, ByokService } from '../byok';
|
||||
import { type StreamObject } from '../providers/types';
|
||||
import { CopilotExecutionMetrics } from './execution-metrics';
|
||||
import {
|
||||
@@ -25,8 +26,11 @@ import { mapNativeSemanticError } from './native-errors';
|
||||
import {
|
||||
createNativeToolLoopAdapter,
|
||||
NativeProviderAdapter,
|
||||
type NativeProviderAdapterOptions,
|
||||
} from './tool/native-adapter';
|
||||
|
||||
const logger = new Logger('NativeExecutionEngine');
|
||||
|
||||
function modelIdForError(modelId?: string) {
|
||||
return modelId ?? 'auto';
|
||||
}
|
||||
@@ -60,6 +64,83 @@ function extractTextResponse(response: LlmDispatchResponse) {
|
||||
.trim();
|
||||
}
|
||||
|
||||
function getUsageContext(plan: ExecutionPlan) {
|
||||
const options = 'options' in plan.request ? plan.request.options : undefined;
|
||||
const requestFeatureKind =
|
||||
plan.request.kind === 'text' ||
|
||||
plan.request.kind === 'streamText' ||
|
||||
plan.request.kind === 'streamObject'
|
||||
? 'chat'
|
||||
: plan.request.kind;
|
||||
return {
|
||||
workspaceId: options?.workspace,
|
||||
userId: options?.user,
|
||||
sessionId: options?.session,
|
||||
taskId: options?.taskId,
|
||||
actionId: options?.actionId,
|
||||
billingUnitId: options?.billingUnitId,
|
||||
featureKind: options?.featureKind ?? requestFeatureKind,
|
||||
};
|
||||
}
|
||||
|
||||
async function recordByokUsage(
|
||||
byok: ByokService,
|
||||
plan: ExecutionPlan,
|
||||
input: {
|
||||
providerId?: string;
|
||||
model?: string | null;
|
||||
usage?: LlmDispatchResponse['usage'];
|
||||
}
|
||||
) {
|
||||
const context = getUsageContext(plan);
|
||||
try {
|
||||
await byok.recordUsage({
|
||||
workspaceId: context.workspaceId,
|
||||
userId: context.userId,
|
||||
sessionId: context.sessionId,
|
||||
taskId: context.taskId,
|
||||
actionId: context.actionId,
|
||||
billingUnitId: context.billingUnitId,
|
||||
featureKind: context.featureKind as ByokFeatureKind,
|
||||
providerId: input.providerId,
|
||||
model: input.model,
|
||||
usage: input.usage,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.warn(
|
||||
`Failed to record BYOK usage: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async function recordSingleByokRouteFailure(
|
||||
byok: ByokService,
|
||||
plan: ExecutionPlan,
|
||||
error: unknown
|
||||
) {
|
||||
const [providerId] = plan.routePolicy.fallbackOrder;
|
||||
if (plan.routePolicy.fallbackOrder.length !== 1 || !providerId) {
|
||||
return;
|
||||
}
|
||||
const context = getUsageContext(plan);
|
||||
try {
|
||||
await byok.recordProviderFailure({
|
||||
workspaceId: context.workspaceId,
|
||||
providerId,
|
||||
featureKind: context.featureKind as ByokFeatureKind,
|
||||
error,
|
||||
});
|
||||
} catch (recordError) {
|
||||
logger.warn(
|
||||
`Failed to record BYOK provider failure: ${
|
||||
recordError instanceof Error ? recordError.message : String(recordError)
|
||||
}`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
function recordPreparedDispatch(
|
||||
executionMetrics: CopilotExecutionMetrics | undefined,
|
||||
plan: ExecutionPlan,
|
||||
@@ -72,7 +153,12 @@ function recordPreparedDispatch(
|
||||
);
|
||||
}
|
||||
|
||||
function createNativeChatAdapter(dispatch: NativeChatDispatchPlan) {
|
||||
function createNativeChatAdapter(
|
||||
dispatch: NativeChatDispatchPlan,
|
||||
options?: {
|
||||
onUsage?: NativeProviderAdapterOptions['onUsage'];
|
||||
}
|
||||
) {
|
||||
if (dispatch.hasTools) {
|
||||
return createNativeToolLoopAdapter(
|
||||
{ preparedRoutes: dispatch.routes },
|
||||
@@ -80,6 +166,7 @@ function createNativeChatAdapter(dispatch: NativeChatDispatchPlan) {
|
||||
{
|
||||
maxSteps: dispatch.prepared.maxSteps,
|
||||
nodeTextMiddleware: dispatch.prepared.postprocess?.nodeTextMiddleware,
|
||||
onUsage: options?.onUsage,
|
||||
}
|
||||
);
|
||||
}
|
||||
@@ -95,6 +182,7 @@ function createNativeChatAdapter(dispatch: NativeChatDispatchPlan) {
|
||||
|
||||
return new NativeProviderAdapter(nativeDispatch, {
|
||||
nodeTextMiddleware: dispatch.prepared.postprocess?.nodeTextMiddleware,
|
||||
onUsage: options?.onUsage,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -102,30 +190,38 @@ async function runPreparedValuePlan<TResult>(
|
||||
plan: ExecutionPlan,
|
||||
routeCount: number,
|
||||
executionMetrics: CopilotExecutionMetrics | undefined,
|
||||
run: () => Promise<TResult>
|
||||
run: () => Promise<TResult>,
|
||||
byok: ByokService
|
||||
) {
|
||||
recordPreparedDispatch(executionMetrics, plan, routeCount);
|
||||
try {
|
||||
return await run();
|
||||
} catch (error) {
|
||||
throw mapNativeSemanticError(error);
|
||||
const mapped = mapNativeSemanticError(error);
|
||||
await recordSingleByokRouteFailure(byok, plan, mapped);
|
||||
throw mapped;
|
||||
}
|
||||
}
|
||||
|
||||
async function* mapPreparedStreamErrors<T>(
|
||||
source: AsyncIterable<T>
|
||||
source: AsyncIterable<T>,
|
||||
plan: ExecutionPlan,
|
||||
byok: ByokService
|
||||
): AsyncIterableIterator<T> {
|
||||
try {
|
||||
yield* source;
|
||||
} catch (error) {
|
||||
throw mapNativeSemanticError(error);
|
||||
const mapped = mapNativeSemanticError(error);
|
||||
await recordSingleByokRouteFailure(byok, plan, mapped);
|
||||
throw mapped;
|
||||
}
|
||||
}
|
||||
|
||||
async function runChatValuePlan(
|
||||
plan: ExecutionPlan,
|
||||
dispatch: NativeChatDispatchPlan,
|
||||
executionMetrics?: CopilotExecutionMetrics
|
||||
executionMetrics: CopilotExecutionMetrics | undefined,
|
||||
byok: ByokService
|
||||
) {
|
||||
const adapter = createNativeChatAdapter(dispatch);
|
||||
return await runPreparedValuePlan(
|
||||
@@ -140,6 +236,11 @@ async function runChatValuePlan(
|
||||
const result = await llmDispatchPlan({
|
||||
preparedRoutes: dispatch.routes,
|
||||
});
|
||||
await recordByokUsage(byok, plan, {
|
||||
providerId: result.provider_id,
|
||||
model: result.response.model,
|
||||
usage: result.response.usage,
|
||||
});
|
||||
return extractTextResponse(result.response);
|
||||
}
|
||||
|
||||
@@ -152,16 +253,26 @@ async function runChatValuePlan(
|
||||
plan.hostContext.signal,
|
||||
plan.request.messages
|
||||
);
|
||||
}
|
||||
},
|
||||
byok
|
||||
);
|
||||
}
|
||||
|
||||
async function* runChatStreamPlan(
|
||||
plan: ExecutionPlan,
|
||||
dispatch: NativeChatDispatchPlan,
|
||||
executionMetrics?: CopilotExecutionMetrics
|
||||
executionMetrics: CopilotExecutionMetrics | undefined,
|
||||
byok: ByokService
|
||||
): AsyncIterableIterator<string | StreamObject> {
|
||||
const adapter = createNativeChatAdapter(dispatch);
|
||||
const adapter = createNativeChatAdapter(dispatch, {
|
||||
onUsage: async usage => {
|
||||
await recordByokUsage(byok, plan, {
|
||||
providerId: usage.providerId,
|
||||
model: usage.model,
|
||||
usage: usage.usage,
|
||||
});
|
||||
},
|
||||
});
|
||||
recordPreparedDispatch(executionMetrics, plan, dispatch.routes.length);
|
||||
|
||||
if (plan.request.kind === 'streamText') {
|
||||
@@ -170,7 +281,9 @@ async function* runChatStreamPlan(
|
||||
dispatch.prepared.request,
|
||||
plan.hostContext.signal,
|
||||
plan.request.messages
|
||||
)
|
||||
),
|
||||
plan,
|
||||
byok
|
||||
);
|
||||
return;
|
||||
}
|
||||
@@ -181,7 +294,9 @@ async function* runChatStreamPlan(
|
||||
dispatch.prepared.request,
|
||||
plan.hostContext.signal,
|
||||
plan.request.messages
|
||||
)
|
||||
),
|
||||
plan,
|
||||
byok
|
||||
);
|
||||
return;
|
||||
}
|
||||
@@ -192,7 +307,8 @@ async function* runChatStreamPlan(
|
||||
async function* runPreparedImageArtifactPlan(
|
||||
dispatch: NativeImageDispatchPlan,
|
||||
plan: ExecutionPlan,
|
||||
executionMetrics?: CopilotExecutionMetrics
|
||||
executionMetrics: CopilotExecutionMetrics | undefined,
|
||||
byok: ByokService
|
||||
): AsyncIterableIterator<NativeImageArtifact> {
|
||||
if (plan.request.kind !== 'image') {
|
||||
throw new Error('image dispatch requires image plan');
|
||||
@@ -204,8 +320,21 @@ async function* runPreparedImageArtifactPlan(
|
||||
result = await llmImageDispatchPlan({
|
||||
preparedRoutes: dispatch.routes,
|
||||
});
|
||||
await recordByokUsage(byok, plan, {
|
||||
providerId: result.provider_id,
|
||||
model: dispatch.prepared.route.model,
|
||||
usage: result.response.usage
|
||||
? {
|
||||
prompt_tokens: result.response.usage.input_tokens ?? 0,
|
||||
completion_tokens: result.response.usage.output_tokens ?? 0,
|
||||
total_tokens: result.response.usage.total_tokens ?? 0,
|
||||
}
|
||||
: undefined,
|
||||
});
|
||||
} catch (error) {
|
||||
throw mapNativeSemanticError(error);
|
||||
const mapped = mapNativeSemanticError(error);
|
||||
await recordSingleByokRouteFailure(byok, plan, mapped);
|
||||
throw mapped;
|
||||
}
|
||||
for (const artifact of result.response.images) {
|
||||
yield artifact;
|
||||
@@ -214,13 +343,14 @@ async function* runPreparedImageArtifactPlan(
|
||||
|
||||
async function executePreparedPlan(
|
||||
plan: ExecutionPlan,
|
||||
executionMetrics?: CopilotExecutionMetrics
|
||||
executionMetrics: CopilotExecutionMetrics | undefined,
|
||||
byok: ByokService
|
||||
): Promise<string | number[][] | number[] | null> {
|
||||
switch (plan.request.kind) {
|
||||
case 'text': {
|
||||
const dispatch = plan.nativeDispatch?.chat;
|
||||
return dispatch
|
||||
? await runChatValuePlan(plan, dispatch, executionMetrics)
|
||||
? await runChatValuePlan(plan, dispatch, executionMetrics, byok)
|
||||
: null;
|
||||
}
|
||||
case 'structured': {
|
||||
@@ -236,13 +366,19 @@ async function executePreparedPlan(
|
||||
const result = await llmStructuredDispatchPlan({
|
||||
preparedRoutes: dispatch.routes,
|
||||
});
|
||||
await recordByokUsage(byok, plan, {
|
||||
providerId: result.provider_id,
|
||||
model: result.response.model,
|
||||
usage: result.response.usage,
|
||||
});
|
||||
const parsed = parseNativeStructuredOutput(result.response);
|
||||
const validated = llmValidateJsonSchema(
|
||||
dispatch.prepared.request.schema,
|
||||
parsed
|
||||
);
|
||||
return JSON.stringify(validated);
|
||||
}
|
||||
},
|
||||
byok
|
||||
);
|
||||
}
|
||||
case 'embedding': {
|
||||
@@ -258,8 +394,20 @@ async function executePreparedPlan(
|
||||
const result = await llmEmbeddingDispatchPlan({
|
||||
preparedRoutes: dispatch.routes,
|
||||
});
|
||||
return result.response.embeddings;
|
||||
await recordByokUsage(byok, plan, {
|
||||
providerId: result.provider_id,
|
||||
model: result.response.model,
|
||||
usage: result.response.usage
|
||||
? {
|
||||
prompt_tokens: result.response.usage.prompt_tokens,
|
||||
completion_tokens: 0,
|
||||
total_tokens: result.response.usage.total_tokens,
|
||||
}
|
||||
: undefined,
|
||||
});
|
||||
return result.response.embeddings;
|
||||
},
|
||||
byok
|
||||
);
|
||||
}
|
||||
case 'rerank': {
|
||||
@@ -275,8 +423,13 @@ async function executePreparedPlan(
|
||||
const result = await llmRerankDispatchPlan({
|
||||
preparedRoutes: dispatch.routes,
|
||||
});
|
||||
await recordByokUsage(byok, plan, {
|
||||
providerId: result.provider_id,
|
||||
model: result.response.model,
|
||||
});
|
||||
return result.response.scores;
|
||||
}
|
||||
},
|
||||
byok
|
||||
);
|
||||
}
|
||||
default:
|
||||
@@ -286,14 +439,15 @@ async function executePreparedPlan(
|
||||
|
||||
function executePreparedStreamPlan(
|
||||
plan: ExecutionPlan,
|
||||
executionMetrics?: CopilotExecutionMetrics
|
||||
executionMetrics: CopilotExecutionMetrics | undefined,
|
||||
byok: ByokService
|
||||
): AsyncIterableIterator<string | StreamObject> | null {
|
||||
switch (plan.request.kind) {
|
||||
case 'streamText':
|
||||
case 'streamObject': {
|
||||
const dispatch = plan.nativeDispatch?.chat;
|
||||
return dispatch
|
||||
? runChatStreamPlan(plan, dispatch, executionMetrics)
|
||||
? runChatStreamPlan(plan, dispatch, executionMetrics, byok)
|
||||
: null;
|
||||
}
|
||||
default:
|
||||
@@ -312,7 +466,10 @@ function noRouteStream<T>(plan: ExecutionPlan) {
|
||||
|
||||
@Injectable()
|
||||
export class NativeExecutionEngine {
|
||||
constructor(private readonly executionMetrics?: CopilotExecutionMetrics) {}
|
||||
constructor(
|
||||
private readonly byok: ByokService,
|
||||
private readonly executionMetrics?: CopilotExecutionMetrics
|
||||
) {}
|
||||
|
||||
private noRoute(plan: ExecutionPlan): never {
|
||||
throw new NoCopilotProviderAvailable({
|
||||
@@ -328,7 +485,11 @@ export class NativeExecutionEngine {
|
||||
async execute(
|
||||
plan: ExecutionPlanForKind<ValueExecutionKind>
|
||||
): Promise<string | number[][] | number[]> {
|
||||
const result = await executePreparedPlan(plan, this.executionMetrics);
|
||||
const result = await executePreparedPlan(
|
||||
plan,
|
||||
this.executionMetrics,
|
||||
this.byok
|
||||
);
|
||||
if (result === null) {
|
||||
return this.noRoute(plan);
|
||||
}
|
||||
@@ -345,7 +506,11 @@ export class NativeExecutionEngine {
|
||||
executeStream(
|
||||
plan: ExecutionPlanForKind<StreamExecutionKind>
|
||||
): AsyncIterableIterator<string | StreamObject> {
|
||||
const result = executePreparedStreamPlan(plan, this.executionMetrics);
|
||||
const result = executePreparedStreamPlan(
|
||||
plan,
|
||||
this.executionMetrics,
|
||||
this.byok
|
||||
);
|
||||
if (result) {
|
||||
return result;
|
||||
}
|
||||
@@ -361,7 +526,8 @@ export class NativeExecutionEngine {
|
||||
return runPreparedImageArtifactPlan(
|
||||
dispatch,
|
||||
plan,
|
||||
this.executionMetrics
|
||||
this.executionMetrics,
|
||||
this.byok
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ import {
|
||||
} from '../providers/types';
|
||||
import {
|
||||
buildBlobContentGetter,
|
||||
buildContentGetter,
|
||||
buildDocContentGetter,
|
||||
buildDocCreateHandler,
|
||||
buildDocKeywordSearchGetter,
|
||||
@@ -27,7 +26,6 @@ import {
|
||||
createConversationSummaryTool,
|
||||
createDocComposeTool,
|
||||
createDocCreateTool,
|
||||
createDocEditTool,
|
||||
createDocKeywordSearchTool,
|
||||
createDocReadTool,
|
||||
createDocSemanticSearchTool,
|
||||
@@ -68,6 +66,21 @@ export class ToolRuntime {
|
||||
if (!options?.tools?.length) {
|
||||
return tools;
|
||||
}
|
||||
const runPromptText = (
|
||||
promptName: string,
|
||||
params: Record<string, unknown>
|
||||
) =>
|
||||
this.promptRuntime.runText(promptName, params, {
|
||||
providerOptions: {
|
||||
user: options.user,
|
||||
session: options.session,
|
||||
workspace: options.workspace,
|
||||
byokLeaseId: options.byokLeaseId,
|
||||
billingUnitId: options.billingUnitId,
|
||||
quotaBackedRoutesAllowed: options.quotaBackedRoutesAllowed,
|
||||
featureKind: options.featureKind,
|
||||
},
|
||||
});
|
||||
|
||||
for (const tool of options.tools) {
|
||||
const toolDef = resolveProviderSpecificTool?.(tool, model);
|
||||
@@ -97,23 +110,13 @@ export class ToolRuntime {
|
||||
break;
|
||||
}
|
||||
case 'codeArtifact': {
|
||||
tools.code_artifact = createCodeArtifactTool(
|
||||
this.promptRuntime.runText.bind(this.promptRuntime)
|
||||
);
|
||||
tools.code_artifact = createCodeArtifactTool(runPromptText);
|
||||
break;
|
||||
}
|
||||
case 'conversationSummary': {
|
||||
tools.conversation_summary = createConversationSummaryTool(
|
||||
options.session,
|
||||
this.promptRuntime.runText.bind(this.promptRuntime)
|
||||
);
|
||||
break;
|
||||
}
|
||||
case 'docEdit': {
|
||||
const getDocContent = buildContentGetter(this.ac, this.docReader);
|
||||
tools.doc_edit = createDocEditTool(
|
||||
this.promptRuntime.runText.bind(this.promptRuntime),
|
||||
getDocContent.bind(null, options)
|
||||
runPromptText
|
||||
);
|
||||
break;
|
||||
}
|
||||
@@ -177,15 +180,11 @@ export class ToolRuntime {
|
||||
break;
|
||||
}
|
||||
case 'docCompose': {
|
||||
tools.doc_compose = createDocComposeTool(
|
||||
this.promptRuntime.runText.bind(this.promptRuntime)
|
||||
);
|
||||
tools.doc_compose = createDocComposeTool(runPromptText);
|
||||
break;
|
||||
}
|
||||
case 'sectionEdit': {
|
||||
tools.section_edit = createSectionEditTool(
|
||||
this.promptRuntime.runText.bind(this.promptRuntime)
|
||||
);
|
||||
tools.section_edit = createSectionEditTool(runPromptText);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
|
||||
import type { LlmRequest, LlmToolLoopStreamEvent } from '../../../../native';
|
||||
import type { NodeTextMiddleware } from '../../config';
|
||||
import type { PromptMessage, StreamObject } from '../../providers/types';
|
||||
@@ -20,9 +22,14 @@ type AttachmentFootnote = {
|
||||
fileType: string;
|
||||
};
|
||||
|
||||
type NativeProviderAdapterOptions = {
|
||||
export type NativeProviderAdapterOptions = {
|
||||
maxSteps?: number;
|
||||
nodeTextMiddleware?: NodeTextMiddleware[];
|
||||
onUsage?: (input: {
|
||||
providerId: string;
|
||||
model?: string;
|
||||
usage?: Extract<LlmToolLoopStreamEvent, { type: 'usage' }>['usage'];
|
||||
}) => void | Promise<void>;
|
||||
};
|
||||
|
||||
type NativeStreamDispatch = ConstructorParameters<
|
||||
@@ -103,9 +110,11 @@ function formatAttachmentFootnotes(
|
||||
}
|
||||
|
||||
export class NativeProviderAdapter {
|
||||
readonly logger = new Logger(NativeProviderAdapter.name);
|
||||
readonly #runtime: NativeRuntimeAdapter;
|
||||
readonly #enableCallout: boolean;
|
||||
readonly #enableCitationFootnote: boolean;
|
||||
readonly #onUsage?: NativeProviderAdapterOptions['onUsage'];
|
||||
|
||||
constructor(
|
||||
dispatchWithTools: NativeStreamDispatch,
|
||||
@@ -120,6 +129,36 @@ export class NativeProviderAdapter {
|
||||
enabledNodeTextMiddlewares.has('thinking_format');
|
||||
this.#enableCitationFootnote =
|
||||
enabledNodeTextMiddlewares.has('citation_footnote');
|
||||
this.#onUsage = options.onUsage;
|
||||
}
|
||||
|
||||
async #recordUsageOnProviderSelected(
|
||||
event: { type: string; [key: string]: unknown },
|
||||
state: {
|
||||
model?: string;
|
||||
usage?: Extract<LlmToolLoopStreamEvent, { type: 'usage' }>['usage'];
|
||||
}
|
||||
) {
|
||||
if (
|
||||
event.type !== 'provider_selected' ||
|
||||
typeof event.provider_id !== 'string'
|
||||
) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
await this.#onUsage?.({
|
||||
providerId: event.provider_id,
|
||||
model: state.model,
|
||||
usage: state.usage,
|
||||
});
|
||||
} catch (error) {
|
||||
this.logger.warn(
|
||||
`Provider usage callback failed: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`
|
||||
);
|
||||
}
|
||||
state.usage = undefined;
|
||||
}
|
||||
|
||||
async text(
|
||||
@@ -144,6 +183,10 @@ export class NativeProviderAdapter {
|
||||
? new CitationFootnoteFormatter()
|
||||
: null;
|
||||
let streamPartId = 0;
|
||||
const usageState: {
|
||||
model?: string;
|
||||
usage?: Extract<LlmToolLoopStreamEvent, { type: 'usage' }>['usage'];
|
||||
} = {};
|
||||
|
||||
for await (const event of this.#runtime.streamEvents(
|
||||
request,
|
||||
@@ -151,6 +194,22 @@ export class NativeProviderAdapter {
|
||||
messages
|
||||
)) {
|
||||
switch (event.type) {
|
||||
case 'message_start': {
|
||||
const startEvent = event as Extract<
|
||||
LlmToolLoopStreamEvent,
|
||||
{ type: 'message_start' }
|
||||
>;
|
||||
usageState.model = startEvent.model;
|
||||
break;
|
||||
}
|
||||
case 'usage': {
|
||||
const usageEvent = event as Extract<
|
||||
LlmToolLoopStreamEvent,
|
||||
{ type: 'usage' }
|
||||
>;
|
||||
usageState.usage = usageEvent.usage;
|
||||
break;
|
||||
}
|
||||
case 'text_delta': {
|
||||
const textEvent = event as unknown as { text: string };
|
||||
if (textParser) {
|
||||
@@ -216,6 +275,11 @@ export class NativeProviderAdapter {
|
||||
break;
|
||||
}
|
||||
case 'done': {
|
||||
const doneEvent = event as Extract<
|
||||
LlmToolLoopStreamEvent,
|
||||
{ type: 'done' }
|
||||
>;
|
||||
usageState.usage = doneEvent.usage ?? usageState.usage;
|
||||
const footnotes = textParser?.end() ?? '';
|
||||
const citations = citationFormatter?.end() ?? '';
|
||||
const tails = [citations, footnotes].filter(Boolean).join('\n');
|
||||
@@ -224,6 +288,9 @@ export class NativeProviderAdapter {
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'provider_selected':
|
||||
await this.#recordUsageOnProviderSelected(event, usageState);
|
||||
break;
|
||||
case 'error':
|
||||
throw new Error(
|
||||
typeof event.message === 'string'
|
||||
@@ -246,6 +313,10 @@ export class NativeProviderAdapter {
|
||||
: null;
|
||||
const fallbackAttachmentFootnotes = new Map<string, AttachmentFootnote>();
|
||||
let hasFootnoteReference = false;
|
||||
const usageState: {
|
||||
model?: string;
|
||||
usage?: Extract<LlmToolLoopStreamEvent, { type: 'usage' }>['usage'];
|
||||
} = {};
|
||||
|
||||
for await (const event of this.#runtime.streamEvents(
|
||||
request,
|
||||
@@ -253,6 +324,22 @@ export class NativeProviderAdapter {
|
||||
messages
|
||||
)) {
|
||||
switch (event.type) {
|
||||
case 'message_start': {
|
||||
const startEvent = event as Extract<
|
||||
LlmToolLoopStreamEvent,
|
||||
{ type: 'message_start' }
|
||||
>;
|
||||
usageState.model = startEvent.model;
|
||||
break;
|
||||
}
|
||||
case 'usage': {
|
||||
const usageEvent = event as Extract<
|
||||
LlmToolLoopStreamEvent,
|
||||
{ type: 'usage' }
|
||||
>;
|
||||
usageState.usage = usageEvent.usage;
|
||||
break;
|
||||
}
|
||||
case 'text_delta': {
|
||||
const textEvent = event as unknown as { text: string };
|
||||
if (textEvent.text.includes('[^')) {
|
||||
@@ -302,6 +389,11 @@ export class NativeProviderAdapter {
|
||||
break;
|
||||
}
|
||||
case 'done': {
|
||||
const doneEvent = event as Extract<
|
||||
LlmToolLoopStreamEvent,
|
||||
{ type: 'done' }
|
||||
>;
|
||||
usageState.usage = doneEvent.usage ?? usageState.usage;
|
||||
const citations = citationFormatter?.end() ?? '';
|
||||
if (citations) {
|
||||
hasFootnoteReference = true;
|
||||
@@ -318,6 +410,9 @@ export class NativeProviderAdapter {
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'provider_selected':
|
||||
await this.#recordUsageOnProviderSelected(event, usageState);
|
||||
break;
|
||||
case 'error':
|
||||
throw new Error(
|
||||
typeof event.message === 'string'
|
||||
|
||||
@@ -62,7 +62,7 @@ export class TurnOrchestrator {
|
||||
sessionId,
|
||||
query
|
||||
);
|
||||
const { modelId, reasoning, webSearch, toolsConfig } =
|
||||
const { modelId, reasoning, webSearch, toolsConfig, byokLeaseId } =
|
||||
ChatQuerySchema.parse(query);
|
||||
const promptParams = await this.buildPromptParams(sessionId, {
|
||||
latestTurn: prepared.latestTurn,
|
||||
@@ -82,6 +82,15 @@ export class TurnOrchestrator {
|
||||
reasoning,
|
||||
webSearch,
|
||||
toolsConfig,
|
||||
byokLeaseId,
|
||||
billingUnitId: prepared.latestTurn?.id,
|
||||
quotaBackedRoutesAllowed: prepared.quotaBackedRoutesAllowed,
|
||||
featureKind:
|
||||
selection.responseMode === 'image'
|
||||
? 'image'
|
||||
: selection.responseMode === 'object'
|
||||
? 'action'
|
||||
: 'chat',
|
||||
}),
|
||||
};
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import {
|
||||
type UpdateChatSession,
|
||||
UpdateChatSessionOptions,
|
||||
} from '../../models';
|
||||
import { CopilotAccessPolicy } from './access';
|
||||
import { ConversationPolicy } from './conversation/policy';
|
||||
import { ConversationStore } from './conversation/store';
|
||||
import { type Conversation, promptMessageFromTurn, type Turn } from './core';
|
||||
@@ -186,6 +187,7 @@ export class ChatSessionService {
|
||||
private readonly models: Models,
|
||||
private readonly jobs: JobQueue,
|
||||
private readonly store: ConversationStore,
|
||||
private readonly access: CopilotAccessPolicy,
|
||||
private readonly conversationPolicy: ConversationPolicy,
|
||||
private readonly prompts: PromptService,
|
||||
private readonly promptRuntime: PromptRuntime
|
||||
@@ -298,11 +300,11 @@ export class ChatSessionService {
|
||||
}
|
||||
|
||||
async getQuota(userId: string) {
|
||||
return await this.conversationPolicy.getQuota(userId);
|
||||
return await this.access.getQuota(userId);
|
||||
}
|
||||
|
||||
async checkQuota(userId: string) {
|
||||
await this.conversationPolicy.checkQuota(userId);
|
||||
await this.access.checkQuota(userId);
|
||||
}
|
||||
|
||||
async create(options: ChatSessionOptions): Promise<string> {
|
||||
|
||||
@@ -1,213 +0,0 @@
|
||||
import { z } from 'zod';
|
||||
|
||||
import { DocReader } from '../../../core/doc';
|
||||
import { AccessController } from '../../../core/permission';
|
||||
import { defineTool } from './tool';
|
||||
import type { CopilotChatOptions } from './types';
|
||||
|
||||
type RunPromptText = (
|
||||
promptName: string,
|
||||
params: Record<string, unknown>
|
||||
) => Promise<string>;
|
||||
|
||||
const CodeEditSchema = z
|
||||
.array(
|
||||
z.object({
|
||||
op: z
|
||||
.string()
|
||||
.describe(
|
||||
'A short description of the change, such as "Bold intro name"'
|
||||
),
|
||||
updates: z
|
||||
.string()
|
||||
.describe(
|
||||
'Markdown block fragments that represent the change, including the block_id and type'
|
||||
),
|
||||
})
|
||||
)
|
||||
.describe(
|
||||
'An array of independent semantic changes to apply to the document.'
|
||||
);
|
||||
|
||||
export const buildContentGetter = (ac: AccessController, doc: DocReader) => {
|
||||
const getDocContent = async (options: CopilotChatOptions, docId?: string) => {
|
||||
if (!options || !docId || !options.user || !options.workspace) {
|
||||
return undefined;
|
||||
}
|
||||
const canAccess = await ac
|
||||
.user(options.user)
|
||||
.workspace(options.workspace)
|
||||
.doc(docId)
|
||||
.can('Doc.Read');
|
||||
if (!canAccess) return undefined;
|
||||
const content = await doc.getDocMarkdown(options.workspace, docId, true);
|
||||
return content?.markdown.trim() || undefined;
|
||||
};
|
||||
return getDocContent;
|
||||
};
|
||||
|
||||
export const createDocEditTool = (
|
||||
prompt: RunPromptText,
|
||||
getContent: (targetId?: string) => Promise<string | undefined>
|
||||
) => {
|
||||
return defineTool({
|
||||
description: `
|
||||
Use this tool to propose an edit to a structured Markdown document with identifiable blocks.
|
||||
Each block begins with a comment like <!-- block_id=... -->, and represents a unit of editable content such as a heading, paragraph, list, or code snippet.
|
||||
This will be read by a less intelligent model, which will quickly apply the edit. You should make it clear what the edit is, while also minimizing the unchanged code you write.
|
||||
|
||||
If you receive a markdown without block_id comments, you should call \`doc_read\` tool to get the content.
|
||||
|
||||
Your task is to return a list of block-level changes needed to fulfill the user's intent. **Each change in code_edit must be completely independent: each code_edit entry should only perform a single, isolated change, and must not include the effects of other changes. For example, the updates for a delete operation should only show the context related to the deletion, and must not include any content modified by other operations (such as bolding or insertion). This ensures that each change can be applied independently and in any order.**
|
||||
|
||||
Each change should correspond to a specific user instruction and be represented by one of the following operations:
|
||||
|
||||
replace: Replace the content of a block with updated Markdown.
|
||||
|
||||
delete: Remove a block entirely.
|
||||
|
||||
insert: Add a new block, and specify its block_id and content.
|
||||
|
||||
Important Instructions:
|
||||
- Use the existing block structure as-is. Do not reformat or reorder blocks unless explicitly asked.
|
||||
- When inserting, follow the same format as a replacement, but ensure the new block_id does not conflict with existing IDs.
|
||||
- When replacing content, always keep the original block_id unchanged.
|
||||
- When deleting content, only use the format <!-- delete block_id=xxx -->, and only for valid block_id present in the original <code> content.
|
||||
- Each top-level list item should be a block. Like this:
|
||||
\`\`\`markdown
|
||||
<!-- block_id=001 flavour=affine:list -->
|
||||
* Item 1
|
||||
* SubItem 1
|
||||
<!-- block_id=002 flavour=affine:list -->
|
||||
1. Item 1
|
||||
1. SubItem 1
|
||||
\`\`\`
|
||||
- Your task is to return a list of block-level changes needed to fulfill the user's intent.
|
||||
- **Each change in code_edit must be completely independent: each code_edit entry should only perform a single, isolated change, and must not include the effects of other changes. For example, the updates for a delete operation should only show the context related to the deletion, and must not include any content modified by other operations (such as bolding or insertion). This ensures that each change can be applied independently and in any order.**
|
||||
|
||||
Original Content:
|
||||
\`\`\`markdown
|
||||
<!-- block_id=001 flavour=paragraph -->
|
||||
# Andriy Shevchenko
|
||||
|
||||
<!-- block_id=002 flavour=paragraph -->
|
||||
## Player Profile
|
||||
|
||||
<!-- block_id=003 flavour=paragraph -->
|
||||
Andriy Shevchenko is a legendary Ukrainian striker, best known for his time at AC Milan and Dynamo Kyiv. He won the Ballon d'Or in 2004.
|
||||
|
||||
<!-- block_id=004 flavour=paragraph -->
|
||||
## Career Overview
|
||||
|
||||
<!-- block_id=005 flavour=list -->
|
||||
- Born in 1976 in Ukraine.
|
||||
<!-- block_id=006 flavour=list -->
|
||||
- Rose to fame at Dynamo Kyiv in the 1990s.
|
||||
<!-- block_id=007 flavour=list -->
|
||||
- Starred at AC Milan (1999–2006), scoring over 170 goals.
|
||||
<!-- block_id=008 flavour=list -->
|
||||
- Played for Chelsea (2006–2009) before returning to Kyiv.
|
||||
<!-- block_id=009 flavour=list -->
|
||||
- Coached Ukraine national team, reaching Euro 2020 quarter-finals.
|
||||
\`\`\`
|
||||
|
||||
User Request:
|
||||
\`\`\`
|
||||
Bold the player’s name in the intro, add a summary section at the end, and remove the career overview.
|
||||
\`\`\`
|
||||
|
||||
Example response:
|
||||
\`\`\`json
|
||||
[
|
||||
{
|
||||
"op": "Bold the player's name in the introduction",
|
||||
"updates": "
|
||||
<!-- block_id=003 flavour=paragraph -->
|
||||
**Andriy Shevchenko** is a legendary Ukrainian striker, best known for his time at AC Milan and Dynamo Kyiv. He won the Ballon d'Or in 2004.
|
||||
"
|
||||
},
|
||||
{
|
||||
"op": "Add a summary section at the end",
|
||||
"updates": "
|
||||
<!-- block_id=new-abc123 flavour=paragraph -->
|
||||
## Summary
|
||||
<!-- block_id=new-def456 flavour=paragraph -->
|
||||
Shevchenko is celebrated as one of the greatest Ukrainian footballers of all time. Known for his composure, strength, and goal-scoring instinct, he left a lasting legacy both on and off the pitch.
|
||||
"
|
||||
},
|
||||
{
|
||||
"op": "Delete the career overview section",
|
||||
"updates": "
|
||||
<!-- delete block_id=004 -->
|
||||
<!-- delete block_id=005 -->
|
||||
<!-- delete block_id=006 -->
|
||||
<!-- delete block_id=007 -->
|
||||
<!-- delete block_id=008 -->
|
||||
<!-- delete block_id=009 -->
|
||||
"
|
||||
}
|
||||
]
|
||||
\`\`\`
|
||||
You should specify the following arguments before the others: [doc_id], [origin_content]
|
||||
|
||||
`,
|
||||
inputSchema: z.object({
|
||||
doc_id: z
|
||||
.string()
|
||||
.describe(
|
||||
'The unique ID of the document being edited. Required when editing an existing document stored in the system. If you are editing ad-hoc Markdown content instead, leave this empty and use origin_content.'
|
||||
)
|
||||
.optional(),
|
||||
|
||||
origin_content: z
|
||||
.string()
|
||||
.describe(
|
||||
'The full original Markdown content, including all block_id comments (e.g., <!-- block_id=block-001 type=paragraph -->). Required when doc_id is not provided. This content will be parsed into discrete editable blocks.'
|
||||
)
|
||||
.optional(),
|
||||
|
||||
instructions: z
|
||||
.string()
|
||||
.describe(
|
||||
'A short, first-person description of the intended edit, clearly summarizing what I will change. For example: "I will translate the steps into English and delete the paragraph explaining the delay." This helps the downstream system understand the purpose of the changes.'
|
||||
),
|
||||
|
||||
code_edit: z.preprocess(val => {
|
||||
// BACKGROUND: LLM sometimes returns a JSON string instead of an array.
|
||||
if (typeof val === 'string') {
|
||||
return JSON.parse(val);
|
||||
}
|
||||
return val;
|
||||
}, CodeEditSchema) as unknown as typeof CodeEditSchema,
|
||||
}),
|
||||
execute: async ({ doc_id, origin_content, code_edit }) => {
|
||||
try {
|
||||
const content = origin_content || (await getContent(doc_id));
|
||||
if (!content) {
|
||||
return 'Doc not found or doc is empty';
|
||||
}
|
||||
|
||||
const changedContents = await Promise.all(
|
||||
code_edit.map(async edit => {
|
||||
return await prompt('Apply Updates', {
|
||||
content,
|
||||
op: edit.op,
|
||||
updates: edit.updates,
|
||||
});
|
||||
})
|
||||
);
|
||||
|
||||
return {
|
||||
result: changedContents.map((changedContent, index) => ({
|
||||
op: code_edit[index].op,
|
||||
updates: code_edit[index].updates,
|
||||
originalContent: content,
|
||||
changedContent,
|
||||
})),
|
||||
};
|
||||
} catch {
|
||||
return 'Failed to apply edit to the doc';
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -13,6 +13,11 @@ import { toolError } from './error';
|
||||
import { defineTool } from './tool';
|
||||
import type { CopilotChatOptions } from './types';
|
||||
|
||||
const getEmbeddingRouteContext = (options: CopilotChatOptions) => ({
|
||||
userId: options?.user,
|
||||
byokLeaseId: options?.byokLeaseId,
|
||||
});
|
||||
|
||||
export const buildDocSearchGetter = (
|
||||
ac: AccessController,
|
||||
context: CopilotContextService,
|
||||
@@ -43,12 +48,32 @@ export const buildDocSearchGetter = (
|
||||
'Doc Semantic Search Failed',
|
||||
'You do not have permission to access this workspace.'
|
||||
);
|
||||
const routeContext = getEmbeddingRouteContext(options);
|
||||
const [chunks, contextChunks] = await Promise.all([
|
||||
context.matchWorkspaceAll(options.workspace, query, 10, signal),
|
||||
context.matchWorkspaceAll(
|
||||
options.workspace,
|
||||
query,
|
||||
10,
|
||||
signal,
|
||||
0.8,
|
||||
undefined,
|
||||
0.85,
|
||||
routeContext
|
||||
),
|
||||
sessionId
|
||||
? context
|
||||
.getBySessionId(sessionId)
|
||||
.then(current => current?.matchFiles(query, 10, signal) ?? [])
|
||||
.then(
|
||||
current =>
|
||||
current?.matchFiles(
|
||||
query,
|
||||
10,
|
||||
signal,
|
||||
0.85,
|
||||
0.5,
|
||||
routeContext
|
||||
) ?? []
|
||||
)
|
||||
: [],
|
||||
]);
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ export * from './blob-read';
|
||||
export * from './code-artifact';
|
||||
export * from './conversation-summary';
|
||||
export * from './doc-compose';
|
||||
export * from './doc-edit';
|
||||
export * from './doc-keyword-search';
|
||||
export * from './doc-read';
|
||||
export * from './doc-semantic-search';
|
||||
|
||||
@@ -10,7 +10,7 @@ import {
|
||||
sniffMime,
|
||||
} from '../../../base';
|
||||
import { Models } from '../../../models';
|
||||
import { ConversationPolicy } from '../conversation/policy';
|
||||
import { CopilotAccessPolicy } from '../access';
|
||||
import { PromptService } from '../prompt';
|
||||
import { CopilotProviderType } from '../providers/types';
|
||||
import { ActionRuntimeBridge } from '../runtime/action-runtime-bridge';
|
||||
@@ -62,7 +62,7 @@ export class CopilotTranscriptionService {
|
||||
private readonly tasks: TaskPolicy,
|
||||
private readonly prompts: PromptService,
|
||||
private readonly actionBridge: ActionRuntimeBridge,
|
||||
@Optional() private readonly conversationPolicy?: ConversationPolicy
|
||||
@Optional() private readonly access?: CopilotAccessPolicy
|
||||
) {}
|
||||
|
||||
private parseTaskPayload(payload: unknown): TranscriptionPayloadV2 {
|
||||
@@ -223,6 +223,12 @@ export class CopilotTranscriptionService {
|
||||
throw new CopilotTranscriptionJobExists();
|
||||
}
|
||||
|
||||
await this.access?.assertQuotaOrByok({
|
||||
userId,
|
||||
workspaceId,
|
||||
featureKind: 'transcript',
|
||||
});
|
||||
|
||||
const { model, strategy } = await this.resolveTranscriptStrategy(
|
||||
userId,
|
||||
input?.strategy ?? undefined
|
||||
@@ -270,6 +276,12 @@ export class CopilotTranscriptionService {
|
||||
);
|
||||
}
|
||||
|
||||
await this.access?.assertQuotaOrByok({
|
||||
userId,
|
||||
workspaceId,
|
||||
featureKind: 'transcript',
|
||||
});
|
||||
|
||||
const payload = this.parseTaskPayload(task.protectedResult);
|
||||
const { model } = await this.resolveTranscriptStrategy(
|
||||
userId,
|
||||
@@ -307,13 +319,17 @@ export class CopilotTranscriptionService {
|
||||
return null;
|
||||
}
|
||||
|
||||
const settled =
|
||||
task.status === 'settled'
|
||||
? task
|
||||
: await (async () => {
|
||||
await this.conversationPolicy?.checkQuota(userId);
|
||||
return await this.models.copilotTranscriptTask.settle(task.id);
|
||||
})();
|
||||
if (task.status === 'settled') {
|
||||
return this.taskToJob(task);
|
||||
}
|
||||
|
||||
await this.access?.assertQuotaOrByok({
|
||||
userId,
|
||||
workspaceId,
|
||||
featureKind: 'transcript',
|
||||
});
|
||||
|
||||
const settled = await this.models.copilotTranscriptTask.settle(task.id);
|
||||
return this.taskToJob(settled);
|
||||
}
|
||||
|
||||
@@ -378,6 +394,13 @@ export class CopilotTranscriptionService {
|
||||
stepId: 'transcribe',
|
||||
modelId,
|
||||
messages,
|
||||
options: {
|
||||
user: task.userId,
|
||||
workspace: task.workspaceId,
|
||||
taskId,
|
||||
billingUnitId: taskId,
|
||||
featureKind: 'transcript',
|
||||
},
|
||||
prefer: CopilotProviderType.Gemini,
|
||||
responseContract: TranscriptActionResultContract,
|
||||
},
|
||||
|
||||
@@ -37,6 +37,7 @@ export const ChatQuerySchema = z
|
||||
.object({
|
||||
messageId: zMaybeString,
|
||||
modelId: zMaybeString,
|
||||
byokLeaseId: zMaybeString,
|
||||
retry: zBool,
|
||||
reasoning: zBool,
|
||||
webSearch: zBool,
|
||||
@@ -47,6 +48,7 @@ export const ChatQuerySchema = z
|
||||
({
|
||||
messageId,
|
||||
modelId,
|
||||
byokLeaseId,
|
||||
retry,
|
||||
reasoning,
|
||||
webSearch,
|
||||
@@ -55,6 +57,7 @@ export const ChatQuerySchema = z
|
||||
}) => ({
|
||||
messageId,
|
||||
modelId,
|
||||
byokLeaseId,
|
||||
retry,
|
||||
reasoning,
|
||||
webSearch,
|
||||
|
||||
@@ -288,6 +288,24 @@ type BlobUploadedPart {
|
||||
partNumber: Int!
|
||||
}
|
||||
|
||||
enum ByokKeyStorage {
|
||||
local
|
||||
server
|
||||
}
|
||||
|
||||
enum ByokKeyTestStatus {
|
||||
failed
|
||||
passed
|
||||
untested
|
||||
}
|
||||
|
||||
enum ByokProvider {
|
||||
anthropic
|
||||
fal
|
||||
gemini
|
||||
openai
|
||||
}
|
||||
|
||||
type CalendarAccountObjectType {
|
||||
calendars: [CalendarSubscriptionObjectType!]!
|
||||
calendarsCount: Int!
|
||||
@@ -731,6 +749,26 @@ input CreateUserInput {
|
||||
password: String
|
||||
}
|
||||
|
||||
input CreateWorkspaceByokLocalLeaseInput {
|
||||
providers: [CreateWorkspaceByokLocalLeaseProviderInput!]!
|
||||
workspaceId: String!
|
||||
}
|
||||
|
||||
input CreateWorkspaceByokLocalLeaseProviderInput {
|
||||
apiKey: String!
|
||||
description: String
|
||||
enabled: Boolean
|
||||
endpoint: String
|
||||
name: String!
|
||||
provider: ByokProvider!
|
||||
sortOrder: SafeInt
|
||||
}
|
||||
|
||||
type CreateWorkspaceByokLocalLeaseResultType {
|
||||
expiresAt: DateTime!
|
||||
leaseId: String!
|
||||
}
|
||||
|
||||
type CredentialsRequirementType {
|
||||
password: PasswordLimitsType!
|
||||
}
|
||||
@@ -1514,9 +1552,6 @@ type Mutation {
|
||||
|
||||
"""Update workspace flags and features for admin"""
|
||||
adminUpdateWorkspace(input: AdminUpdateWorkspaceInput!): AdminWorkspace
|
||||
|
||||
"""Apply updates to a doc using LLM and return the merged markdown."""
|
||||
applyDocUpdates(docId: String!, op: String!, updates: String!, workspaceId: String!): String!
|
||||
approveMember(userId: String!, workspaceId: String!): Boolean!
|
||||
|
||||
"""Ban an user"""
|
||||
@@ -1527,6 +1562,7 @@ type Mutation {
|
||||
|
||||
"""Cleanup sessions"""
|
||||
cleanupCopilotSession(options: DeleteSessionInput!): [String!]!
|
||||
clearWorkspaceByokConfigs(provider: ByokProvider, workspaceId: String!): Boolean!
|
||||
completeBlobUpload(key: String!, parts: [BlobUploadPartInput!], uploadId: String, workspaceId: String!): String!
|
||||
createBlobUpload(key: String!, mime: String!, size: Int!, workspaceId: String!): BlobUploadInit!
|
||||
|
||||
@@ -1560,6 +1596,7 @@ type Mutation {
|
||||
|
||||
"""Create a new workspace"""
|
||||
createWorkspace(init: Upload): WorkspaceType!
|
||||
createWorkspaceByokLocalLease(input: CreateWorkspaceByokLocalLeaseInput!): CreateWorkspaceByokLocalLeaseResultType!
|
||||
deactivateLicense(workspaceId: String!): Boolean!
|
||||
deleteAccount: DeleteAccount!
|
||||
deleteBlob(hash: String @deprecated(reason: "use parameter [key]"), key: String, permanently: Boolean! = false, workspaceId: String!): Boolean!
|
||||
@@ -1573,6 +1610,7 @@ type Mutation {
|
||||
"""Delete a user account"""
|
||||
deleteUser(id: String!): DeleteAccount!
|
||||
deleteWorkspace(id: String!): Boolean!
|
||||
deleteWorkspaceByokConfig(id: ID!, workspaceId: String!): Boolean!
|
||||
|
||||
"""Reenable an banned user"""
|
||||
enableUser(id: String!): UserType!
|
||||
@@ -1628,6 +1666,7 @@ type Mutation {
|
||||
"""Remove workspace embedding files"""
|
||||
removeWorkspaceEmbeddingFiles(fileId: String!, workspaceId: String!): Boolean!
|
||||
removeWorkspaceFeature(feature: FeatureType!, workspaceId: String!): Boolean!
|
||||
reorderWorkspaceByokConfigs(input: ReorderWorkspaceByokConfigsInput!): [WorkspaceByokKeyConfigType!]!
|
||||
|
||||
"""Request to apply the subscription in advance"""
|
||||
requestApplySubscription(transactionId: String!): [SubscriptionType!]!
|
||||
@@ -1650,6 +1689,7 @@ type Mutation {
|
||||
setBlob(blob: Upload!, workspaceId: String!): String!
|
||||
settleTranscriptTask(taskId: String!, workspaceId: String!): TranscriptionResultType
|
||||
submitTranscriptTask(blob: Upload, blobId: String!, blobs: [Upload!], input: SubmitAudioTranscriptionInput, workspaceId: String!): TranscriptionResultType
|
||||
testWorkspaceByokConfig(input: TestWorkspaceByokConfigInput!): TestWorkspaceByokConfigResultType!
|
||||
unlinkCalendarAccount(accountId: String!): Boolean!
|
||||
|
||||
"""update app configuration"""
|
||||
@@ -1690,6 +1730,7 @@ type Mutation {
|
||||
|
||||
"""Upload a comment attachment and return the access url"""
|
||||
uploadCommentAttachment(attachment: Upload!, docId: String!, workspaceId: String!): String!
|
||||
upsertWorkspaceByokConfig(input: UpsertWorkspaceByokConfigInput!): WorkspaceByokKeyConfigType!
|
||||
verifyEmail(token: String!): Boolean!
|
||||
}
|
||||
|
||||
@@ -1907,9 +1948,6 @@ type Query {
|
||||
"""get the whole app configuration"""
|
||||
appConfig: JSONObject!
|
||||
|
||||
"""Apply updates to a doc using LLM and return the merged markdown."""
|
||||
applyDocUpdates(docId: String!, op: String!, updates: String!, workspaceId: String!): String! @deprecated(reason: "use Mutation.applyDocUpdates")
|
||||
|
||||
"""Get current user"""
|
||||
currentUser: UserType
|
||||
error(name: ErrorNames!): ErrorDataUnion!
|
||||
@@ -2013,6 +2051,12 @@ input RemoveContextFileInput {
|
||||
fileId: String!
|
||||
}
|
||||
|
||||
input ReorderWorkspaceByokConfigsInput {
|
||||
ids: [ID!]!
|
||||
storage: ByokKeyStorage!
|
||||
workspaceId: String!
|
||||
}
|
||||
|
||||
input ReplyCreateInput {
|
||||
commentId: ID!
|
||||
content: JSONObject!
|
||||
@@ -2344,6 +2388,21 @@ enum SubscriptionVariant {
|
||||
Onetime
|
||||
}
|
||||
|
||||
input TestWorkspaceByokConfigInput {
|
||||
apiKey: String
|
||||
configId: ID
|
||||
endpoint: String
|
||||
provider: ByokProvider!
|
||||
storage: ByokKeyStorage!
|
||||
workspaceId: String!
|
||||
}
|
||||
|
||||
type TestWorkspaceByokConfigResultType {
|
||||
message: String
|
||||
ok: Boolean!
|
||||
status: ByokKeyTestStatus!
|
||||
}
|
||||
|
||||
enum TimeBucket {
|
||||
Day
|
||||
Minute
|
||||
@@ -2501,6 +2560,19 @@ input UpdateWorkspaceInput {
|
||||
"""The `Upload` scalar type represents a file upload."""
|
||||
scalar Upload
|
||||
|
||||
input UpsertWorkspaceByokConfigInput {
|
||||
apiKey: String
|
||||
description: String
|
||||
enabled: Boolean
|
||||
endpoint: String
|
||||
id: ID
|
||||
name: String!
|
||||
provider: ByokProvider!
|
||||
sortOrder: SafeInt
|
||||
storage: ByokKeyStorage!
|
||||
workspaceId: String!
|
||||
}
|
||||
|
||||
type UserImportFailedType {
|
||||
email: String!
|
||||
error: String!
|
||||
@@ -2604,6 +2676,53 @@ type VersionRejectedDataType {
|
||||
version: String!
|
||||
}
|
||||
|
||||
type WorkspaceByokCapabilityWarningType {
|
||||
featureKind: String!
|
||||
reason: String!
|
||||
requiredProviders: [ByokProvider!]!
|
||||
}
|
||||
|
||||
type WorkspaceByokKeyConfigType {
|
||||
capabilities: [String!]!
|
||||
configured: Boolean!
|
||||
description: String
|
||||
disabledReason: String
|
||||
enabled: Boolean!
|
||||
endpoint: String
|
||||
endpointEditable: Boolean!
|
||||
id: ID!
|
||||
lastError: String
|
||||
lastErrorAt: DateTime
|
||||
lastTestError: String
|
||||
lastTestedAt: DateTime
|
||||
lastUsedAt: DateTime
|
||||
name: String!
|
||||
provider: ByokProvider!
|
||||
sortOrder: SafeInt!
|
||||
storage: ByokKeyStorage!
|
||||
testStatus: ByokKeyTestStatus!
|
||||
}
|
||||
|
||||
type WorkspaceByokSettingsType {
|
||||
allowedProviders: [ByokProvider!]!
|
||||
customEndpointSupported: Boolean!
|
||||
entitled: Boolean!
|
||||
entitlementRequired: [String!]!
|
||||
hasAiPlan: Boolean!
|
||||
keys: [WorkspaceByokKeyConfigType!]!
|
||||
localEntitled: Boolean!
|
||||
localStorageSupported: Boolean!
|
||||
serverEntitled: Boolean!
|
||||
warnings: [WorkspaceByokCapabilityWarningType!]!
|
||||
workspaceId: String!
|
||||
}
|
||||
|
||||
type WorkspaceByokUsagePointType {
|
||||
date: DateTime!
|
||||
featureKind: String!
|
||||
totalTokens: SafeInt!
|
||||
}
|
||||
|
||||
input WorkspaceCalendarItemInput {
|
||||
colorOverride: String
|
||||
sortOrder: Int
|
||||
@@ -2722,6 +2841,8 @@ type WorkspaceType {
|
||||
|
||||
"""Blobs size of workspace"""
|
||||
blobsSize: Int!
|
||||
byokSettings: WorkspaceByokSettingsType!
|
||||
byokUsage(from: DateTime!, to: DateTime!): [WorkspaceByokUsagePointType!]!
|
||||
calendars: [WorkspaceCalendarObjectType!]!
|
||||
|
||||
"""Get comment changes of a doc"""
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
mutation applyDocUpdates(
|
||||
$workspaceId: String!
|
||||
$docId: String!
|
||||
$op: String!
|
||||
$updates: String!
|
||||
) {
|
||||
applyDocUpdates(
|
||||
workspaceId: $workspaceId
|
||||
docId: $docId
|
||||
op: $op
|
||||
updates: $updates
|
||||
)
|
||||
}
|
||||
@@ -1033,19 +1033,6 @@ export const uploadCommentAttachmentMutation = {
|
||||
file: true,
|
||||
};
|
||||
|
||||
export const applyDocUpdatesMutation = {
|
||||
id: 'applyDocUpdatesMutation' as const,
|
||||
op: 'applyDocUpdates',
|
||||
query: `mutation applyDocUpdates($workspaceId: String!, $docId: String!, $op: String!, $updates: String!) {
|
||||
applyDocUpdates(
|
||||
workspaceId: $workspaceId
|
||||
docId: $docId
|
||||
op: $op
|
||||
updates: $updates
|
||||
)
|
||||
}`,
|
||||
};
|
||||
|
||||
export const addContextBlobMutation = {
|
||||
id: 'addContextBlobMutation' as const,
|
||||
op: 'addContextBlob',
|
||||
@@ -2997,6 +2984,117 @@ export const workspaceBlobQuotaQuery = {
|
||||
}`,
|
||||
};
|
||||
|
||||
export const clearWorkspaceByokConfigsMutation = {
|
||||
id: 'clearWorkspaceByokConfigsMutation' as const,
|
||||
op: 'clearWorkspaceByokConfigs',
|
||||
query: `mutation clearWorkspaceByokConfigs($workspaceId: String!) {
|
||||
clearWorkspaceByokConfigs(workspaceId: $workspaceId)
|
||||
}`,
|
||||
};
|
||||
|
||||
export const deleteWorkspaceByokConfigMutation = {
|
||||
id: 'deleteWorkspaceByokConfigMutation' as const,
|
||||
op: 'deleteWorkspaceByokConfig',
|
||||
query: `mutation deleteWorkspaceByokConfig($workspaceId: String!, $id: ID!) {
|
||||
deleteWorkspaceByokConfig(workspaceId: $workspaceId, id: $id)
|
||||
}`,
|
||||
};
|
||||
|
||||
export const reorderWorkspaceByokConfigsMutation = {
|
||||
id: 'reorderWorkspaceByokConfigsMutation' as const,
|
||||
op: 'reorderWorkspaceByokConfigs',
|
||||
query: `mutation reorderWorkspaceByokConfigs($input: ReorderWorkspaceByokConfigsInput!) {
|
||||
reorderWorkspaceByokConfigs(input: $input) {
|
||||
id
|
||||
sortOrder
|
||||
}
|
||||
}`,
|
||||
};
|
||||
|
||||
export const testWorkspaceByokConfigMutation = {
|
||||
id: 'testWorkspaceByokConfigMutation' as const,
|
||||
op: 'testWorkspaceByokConfig',
|
||||
query: `mutation testWorkspaceByokConfig($input: TestWorkspaceByokConfigInput!) {
|
||||
testWorkspaceByokConfig(input: $input) {
|
||||
ok
|
||||
status
|
||||
message
|
||||
}
|
||||
}`,
|
||||
};
|
||||
|
||||
export const upsertWorkspaceByokConfigMutation = {
|
||||
id: 'upsertWorkspaceByokConfigMutation' as const,
|
||||
op: 'upsertWorkspaceByokConfig',
|
||||
query: `mutation upsertWorkspaceByokConfig($input: UpsertWorkspaceByokConfigInput!) {
|
||||
upsertWorkspaceByokConfig(input: $input) {
|
||||
id
|
||||
}
|
||||
}`,
|
||||
};
|
||||
|
||||
export const createWorkspaceByokLocalLeaseMutation = {
|
||||
id: 'createWorkspaceByokLocalLeaseMutation' as const,
|
||||
op: 'createWorkspaceByokLocalLease',
|
||||
query: `mutation createWorkspaceByokLocalLease($input: CreateWorkspaceByokLocalLeaseInput!) {
|
||||
createWorkspaceByokLocalLease(input: $input) {
|
||||
leaseId
|
||||
expiresAt
|
||||
}
|
||||
}`,
|
||||
};
|
||||
|
||||
export const workspaceByokSettingsQuery = {
|
||||
id: 'workspaceByokSettingsQuery' as const,
|
||||
op: 'workspaceByokSettings',
|
||||
query: `query workspaceByokSettings($id: String!, $from: DateTime!, $to: DateTime!) {
|
||||
workspace(id: $id) {
|
||||
id
|
||||
byokSettings {
|
||||
workspaceId
|
||||
entitled
|
||||
serverEntitled
|
||||
localEntitled
|
||||
entitlementRequired
|
||||
allowedProviders
|
||||
localStorageSupported
|
||||
customEndpointSupported
|
||||
hasAiPlan
|
||||
keys {
|
||||
id
|
||||
provider
|
||||
name
|
||||
description
|
||||
storage
|
||||
configured
|
||||
enabled
|
||||
endpoint
|
||||
endpointEditable
|
||||
sortOrder
|
||||
capabilities
|
||||
testStatus
|
||||
disabledReason
|
||||
lastTestedAt
|
||||
lastTestError
|
||||
lastUsedAt
|
||||
lastErrorAt
|
||||
lastError
|
||||
}
|
||||
warnings {
|
||||
featureKind
|
||||
reason
|
||||
requiredProviders
|
||||
}
|
||||
}
|
||||
byokUsage(from: $from, to: $to) {
|
||||
date
|
||||
featureKind
|
||||
totalTokens
|
||||
}
|
||||
}
|
||||
}`,
|
||||
};
|
||||
|
||||
export const getWorkspaceConfigQuery = {
|
||||
id: 'getWorkspaceConfigQuery' as const,
|
||||
op: 'getWorkspaceConfig',
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
mutation clearWorkspaceByokConfigs($workspaceId: String!) {
|
||||
clearWorkspaceByokConfigs(workspaceId: $workspaceId)
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
mutation deleteWorkspaceByokConfig($workspaceId: String!, $id: ID!) {
|
||||
deleteWorkspaceByokConfig(workspaceId: $workspaceId, id: $id)
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
mutation reorderWorkspaceByokConfigs(
|
||||
$input: ReorderWorkspaceByokConfigsInput!
|
||||
) {
|
||||
reorderWorkspaceByokConfigs(input: $input) {
|
||||
id
|
||||
sortOrder
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
mutation testWorkspaceByokConfig($input: TestWorkspaceByokConfigInput!) {
|
||||
testWorkspaceByokConfig(input: $input) {
|
||||
ok
|
||||
status
|
||||
message
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
mutation upsertWorkspaceByokConfig($input: UpsertWorkspaceByokConfigInput!) {
|
||||
upsertWorkspaceByokConfig(input: $input) {
|
||||
id
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
mutation createWorkspaceByokLocalLease(
|
||||
$input: CreateWorkspaceByokLocalLeaseInput!
|
||||
) {
|
||||
createWorkspaceByokLocalLease(input: $input) {
|
||||
leaseId
|
||||
expiresAt
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
query workspaceByokSettings($id: String!, $from: DateTime!, $to: DateTime!) {
|
||||
workspace(id: $id) {
|
||||
id
|
||||
byokSettings {
|
||||
workspaceId
|
||||
entitled
|
||||
serverEntitled
|
||||
localEntitled
|
||||
entitlementRequired
|
||||
allowedProviders
|
||||
localStorageSupported
|
||||
customEndpointSupported
|
||||
hasAiPlan
|
||||
keys {
|
||||
id
|
||||
provider
|
||||
name
|
||||
description
|
||||
storage
|
||||
configured
|
||||
enabled
|
||||
endpoint
|
||||
endpointEditable
|
||||
sortOrder
|
||||
capabilities
|
||||
testStatus
|
||||
disabledReason
|
||||
lastTestedAt
|
||||
lastTestError
|
||||
lastUsedAt
|
||||
lastErrorAt
|
||||
lastError
|
||||
}
|
||||
warnings {
|
||||
featureKind
|
||||
reason
|
||||
requiredProviders
|
||||
}
|
||||
}
|
||||
byokUsage(from: $from, to: $to) {
|
||||
date
|
||||
featureKind
|
||||
totalTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -347,6 +347,24 @@ export interface BlobUploadedPart {
|
||||
partNumber: Scalars['Int']['output'];
|
||||
}
|
||||
|
||||
export enum ByokKeyStorage {
|
||||
local = 'local',
|
||||
server = 'server',
|
||||
}
|
||||
|
||||
export enum ByokKeyTestStatus {
|
||||
failed = 'failed',
|
||||
passed = 'passed',
|
||||
untested = 'untested',
|
||||
}
|
||||
|
||||
export enum ByokProvider {
|
||||
anthropic = 'anthropic',
|
||||
fal = 'fal',
|
||||
gemini = 'gemini',
|
||||
openai = 'openai',
|
||||
}
|
||||
|
||||
export interface CalendarAccountObjectType {
|
||||
__typename?: 'CalendarAccountObjectType';
|
||||
calendars: Array<CalendarSubscriptionObjectType>;
|
||||
@@ -868,6 +886,27 @@ export interface CreateUserInput {
|
||||
password?: InputMaybe<Scalars['String']['input']>;
|
||||
}
|
||||
|
||||
export interface CreateWorkspaceByokLocalLeaseInput {
|
||||
providers: Array<CreateWorkspaceByokLocalLeaseProviderInput>;
|
||||
workspaceId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface CreateWorkspaceByokLocalLeaseProviderInput {
|
||||
apiKey: Scalars['String']['input'];
|
||||
description?: InputMaybe<Scalars['String']['input']>;
|
||||
enabled?: InputMaybe<Scalars['Boolean']['input']>;
|
||||
endpoint?: InputMaybe<Scalars['String']['input']>;
|
||||
name: Scalars['String']['input'];
|
||||
provider: ByokProvider;
|
||||
sortOrder?: InputMaybe<Scalars['SafeInt']['input']>;
|
||||
}
|
||||
|
||||
export interface CreateWorkspaceByokLocalLeaseResultType {
|
||||
__typename?: 'CreateWorkspaceByokLocalLeaseResultType';
|
||||
expiresAt: Scalars['DateTime']['output'];
|
||||
leaseId: Scalars['String']['output'];
|
||||
}
|
||||
|
||||
export interface CredentialsRequirementType {
|
||||
__typename?: 'CredentialsRequirementType';
|
||||
password: PasswordLimitsType;
|
||||
@@ -1727,8 +1766,6 @@ export interface Mutation {
|
||||
addWorkspaceFeature: Scalars['Boolean']['output'];
|
||||
/** Update workspace flags and features for admin */
|
||||
adminUpdateWorkspace: Maybe<AdminWorkspace>;
|
||||
/** Apply updates to a doc using LLM and return the merged markdown. */
|
||||
applyDocUpdates: Scalars['String']['output'];
|
||||
approveMember: Scalars['Boolean']['output'];
|
||||
/** Ban an user */
|
||||
banUser: UserType;
|
||||
@@ -1737,6 +1774,7 @@ export interface Mutation {
|
||||
changePassword: Scalars['Boolean']['output'];
|
||||
/** Cleanup sessions */
|
||||
cleanupCopilotSession: Array<Scalars['String']['output']>;
|
||||
clearWorkspaceByokConfigs: Scalars['Boolean']['output'];
|
||||
completeBlobUpload: Scalars['String']['output'];
|
||||
createBlobUpload: BlobUploadInit;
|
||||
/** Create change password url */
|
||||
@@ -1764,6 +1802,7 @@ export interface Mutation {
|
||||
createUser: UserType;
|
||||
/** Create a new workspace */
|
||||
createWorkspace: WorkspaceType;
|
||||
createWorkspaceByokLocalLease: CreateWorkspaceByokLocalLeaseResultType;
|
||||
deactivateLicense: Scalars['Boolean']['output'];
|
||||
deleteAccount: DeleteAccount;
|
||||
deleteBlob: Scalars['Boolean']['output'];
|
||||
@@ -1774,6 +1813,7 @@ export interface Mutation {
|
||||
/** Delete a user account */
|
||||
deleteUser: DeleteAccount;
|
||||
deleteWorkspace: Scalars['Boolean']['output'];
|
||||
deleteWorkspaceByokConfig: Scalars['Boolean']['output'];
|
||||
/** Reenable an banned user */
|
||||
enableUser: UserType;
|
||||
/** Create a chat session */
|
||||
@@ -1815,6 +1855,7 @@ export interface Mutation {
|
||||
/** Remove workspace embedding files */
|
||||
removeWorkspaceEmbeddingFiles: Scalars['Boolean']['output'];
|
||||
removeWorkspaceFeature: Scalars['Boolean']['output'];
|
||||
reorderWorkspaceByokConfigs: Array<WorkspaceByokKeyConfigType>;
|
||||
/** Request to apply the subscription in advance */
|
||||
requestApplySubscription: Array<SubscriptionType>;
|
||||
/** Resolve a comment or not */
|
||||
@@ -1835,6 +1876,7 @@ export interface Mutation {
|
||||
setBlob: Scalars['String']['output'];
|
||||
settleTranscriptTask: Maybe<TranscriptionResultType>;
|
||||
submitTranscriptTask: Maybe<TranscriptionResultType>;
|
||||
testWorkspaceByokConfig: TestWorkspaceByokConfigResultType;
|
||||
unlinkCalendarAccount: Scalars['Boolean']['output'];
|
||||
/** update app configuration */
|
||||
updateAppConfig: Scalars['JSONObject']['output'];
|
||||
@@ -1864,6 +1906,7 @@ export interface Mutation {
|
||||
uploadAvatar: UserType;
|
||||
/** Upload a comment attachment and return the access url */
|
||||
uploadCommentAttachment: Scalars['String']['output'];
|
||||
upsertWorkspaceByokConfig: WorkspaceByokKeyConfigType;
|
||||
verifyEmail: Scalars['Boolean']['output'];
|
||||
}
|
||||
|
||||
@@ -1915,13 +1958,6 @@ export interface MutationAdminUpdateWorkspaceArgs {
|
||||
input: AdminUpdateWorkspaceInput;
|
||||
}
|
||||
|
||||
export interface MutationApplyDocUpdatesArgs {
|
||||
docId: Scalars['String']['input'];
|
||||
op: Scalars['String']['input'];
|
||||
updates: Scalars['String']['input'];
|
||||
workspaceId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface MutationApproveMemberArgs {
|
||||
userId: Scalars['String']['input'];
|
||||
workspaceId: Scalars['String']['input'];
|
||||
@@ -1952,6 +1988,11 @@ export interface MutationCleanupCopilotSessionArgs {
|
||||
options: DeleteSessionInput;
|
||||
}
|
||||
|
||||
export interface MutationClearWorkspaceByokConfigsArgs {
|
||||
provider?: InputMaybe<ByokProvider>;
|
||||
workspaceId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface MutationCompleteBlobUploadArgs {
|
||||
key: Scalars['String']['input'];
|
||||
parts?: InputMaybe<Array<BlobUploadPartInput>>;
|
||||
@@ -2017,6 +2058,10 @@ export interface MutationCreateWorkspaceArgs {
|
||||
init?: InputMaybe<Scalars['Upload']['input']>;
|
||||
}
|
||||
|
||||
export interface MutationCreateWorkspaceByokLocalLeaseArgs {
|
||||
input: CreateWorkspaceByokLocalLeaseInput;
|
||||
}
|
||||
|
||||
export interface MutationDeactivateLicenseArgs {
|
||||
workspaceId: Scalars['String']['input'];
|
||||
}
|
||||
@@ -2044,6 +2089,11 @@ export interface MutationDeleteWorkspaceArgs {
|
||||
id: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface MutationDeleteWorkspaceByokConfigArgs {
|
||||
id: Scalars['ID']['input'];
|
||||
workspaceId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface MutationEnableUserArgs {
|
||||
id: Scalars['String']['input'];
|
||||
}
|
||||
@@ -2153,6 +2203,10 @@ export interface MutationRemoveWorkspaceFeatureArgs {
|
||||
workspaceId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface MutationReorderWorkspaceByokConfigsArgs {
|
||||
input: ReorderWorkspaceByokConfigsInput;
|
||||
}
|
||||
|
||||
export interface MutationRequestApplySubscriptionArgs {
|
||||
transactionId: Scalars['String']['input'];
|
||||
}
|
||||
@@ -2241,6 +2295,10 @@ export interface MutationSubmitTranscriptTaskArgs {
|
||||
workspaceId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface MutationTestWorkspaceByokConfigArgs {
|
||||
input: TestWorkspaceByokConfigInput;
|
||||
}
|
||||
|
||||
export interface MutationUnlinkCalendarAccountArgs {
|
||||
accountId: Scalars['String']['input'];
|
||||
}
|
||||
@@ -2323,6 +2381,10 @@ export interface MutationUploadCommentAttachmentArgs {
|
||||
workspaceId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface MutationUpsertWorkspaceByokConfigArgs {
|
||||
input: UpsertWorkspaceByokConfigInput;
|
||||
}
|
||||
|
||||
export interface MutationVerifyEmailArgs {
|
||||
token: Scalars['String']['input'];
|
||||
}
|
||||
@@ -2545,11 +2607,6 @@ export interface Query {
|
||||
adminWorkspacesCount: Scalars['Int']['output'];
|
||||
/** get the whole app configuration */
|
||||
appConfig: Scalars['JSONObject']['output'];
|
||||
/**
|
||||
* Apply updates to a doc using LLM and return the merged markdown.
|
||||
* @deprecated use Mutation.applyDocUpdates
|
||||
*/
|
||||
applyDocUpdates: Scalars['String']['output'];
|
||||
/** Get current user */
|
||||
currentUser: Maybe<UserType>;
|
||||
error: ErrorDataUnion;
|
||||
@@ -2608,13 +2665,6 @@ export interface QueryAdminWorkspacesCountArgs {
|
||||
filter: ListWorkspaceInput;
|
||||
}
|
||||
|
||||
export interface QueryApplyDocUpdatesArgs {
|
||||
docId: Scalars['String']['input'];
|
||||
op: Scalars['String']['input'];
|
||||
updates: Scalars['String']['input'];
|
||||
workspaceId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface QueryErrorArgs {
|
||||
name: ErrorNames;
|
||||
}
|
||||
@@ -2723,6 +2773,12 @@ export interface RemoveContextFileInput {
|
||||
fileId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface ReorderWorkspaceByokConfigsInput {
|
||||
ids: Array<Scalars['ID']['input']>;
|
||||
storage: ByokKeyStorage;
|
||||
workspaceId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface ReplyCreateInput {
|
||||
commentId: Scalars['ID']['input'];
|
||||
content: Scalars['JSONObject']['input'];
|
||||
@@ -3046,6 +3102,22 @@ export enum SubscriptionVariant {
|
||||
Onetime = 'Onetime',
|
||||
}
|
||||
|
||||
export interface TestWorkspaceByokConfigInput {
|
||||
apiKey?: InputMaybe<Scalars['String']['input']>;
|
||||
configId?: InputMaybe<Scalars['ID']['input']>;
|
||||
endpoint?: InputMaybe<Scalars['String']['input']>;
|
||||
provider: ByokProvider;
|
||||
storage: ByokKeyStorage;
|
||||
workspaceId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface TestWorkspaceByokConfigResultType {
|
||||
__typename?: 'TestWorkspaceByokConfigResultType';
|
||||
message: Maybe<Scalars['String']['output']>;
|
||||
ok: Scalars['Boolean']['output'];
|
||||
status: ByokKeyTestStatus;
|
||||
}
|
||||
|
||||
export enum TimeBucket {
|
||||
Day = 'Day',
|
||||
Minute = 'Minute',
|
||||
@@ -3208,6 +3280,19 @@ export interface UpdateWorkspaceInput {
|
||||
public?: InputMaybe<Scalars['Boolean']['input']>;
|
||||
}
|
||||
|
||||
export interface UpsertWorkspaceByokConfigInput {
|
||||
apiKey?: InputMaybe<Scalars['String']['input']>;
|
||||
description?: InputMaybe<Scalars['String']['input']>;
|
||||
enabled?: InputMaybe<Scalars['Boolean']['input']>;
|
||||
endpoint?: InputMaybe<Scalars['String']['input']>;
|
||||
id?: InputMaybe<Scalars['ID']['input']>;
|
||||
name: Scalars['String']['input'];
|
||||
provider: ByokProvider;
|
||||
sortOrder?: InputMaybe<Scalars['SafeInt']['input']>;
|
||||
storage: ByokKeyStorage;
|
||||
workspaceId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface UserImportFailedType {
|
||||
__typename?: 'UserImportFailedType';
|
||||
email: Scalars['String']['output'];
|
||||
@@ -3323,6 +3408,57 @@ export interface VersionRejectedDataType {
|
||||
version: Scalars['String']['output'];
|
||||
}
|
||||
|
||||
export interface WorkspaceByokCapabilityWarningType {
|
||||
__typename?: 'WorkspaceByokCapabilityWarningType';
|
||||
featureKind: Scalars['String']['output'];
|
||||
reason: Scalars['String']['output'];
|
||||
requiredProviders: Array<ByokProvider>;
|
||||
}
|
||||
|
||||
export interface WorkspaceByokKeyConfigType {
|
||||
__typename?: 'WorkspaceByokKeyConfigType';
|
||||
capabilities: Array<Scalars['String']['output']>;
|
||||
configured: Scalars['Boolean']['output'];
|
||||
description: Maybe<Scalars['String']['output']>;
|
||||
disabledReason: Maybe<Scalars['String']['output']>;
|
||||
enabled: Scalars['Boolean']['output'];
|
||||
endpoint: Maybe<Scalars['String']['output']>;
|
||||
endpointEditable: Scalars['Boolean']['output'];
|
||||
id: Scalars['ID']['output'];
|
||||
lastError: Maybe<Scalars['String']['output']>;
|
||||
lastErrorAt: Maybe<Scalars['DateTime']['output']>;
|
||||
lastTestError: Maybe<Scalars['String']['output']>;
|
||||
lastTestedAt: Maybe<Scalars['DateTime']['output']>;
|
||||
lastUsedAt: Maybe<Scalars['DateTime']['output']>;
|
||||
name: Scalars['String']['output'];
|
||||
provider: ByokProvider;
|
||||
sortOrder: Scalars['SafeInt']['output'];
|
||||
storage: ByokKeyStorage;
|
||||
testStatus: ByokKeyTestStatus;
|
||||
}
|
||||
|
||||
export interface WorkspaceByokSettingsType {
|
||||
__typename?: 'WorkspaceByokSettingsType';
|
||||
allowedProviders: Array<ByokProvider>;
|
||||
customEndpointSupported: Scalars['Boolean']['output'];
|
||||
entitled: Scalars['Boolean']['output'];
|
||||
entitlementRequired: Array<Scalars['String']['output']>;
|
||||
hasAiPlan: Scalars['Boolean']['output'];
|
||||
keys: Array<WorkspaceByokKeyConfigType>;
|
||||
localEntitled: Scalars['Boolean']['output'];
|
||||
localStorageSupported: Scalars['Boolean']['output'];
|
||||
serverEntitled: Scalars['Boolean']['output'];
|
||||
warnings: Array<WorkspaceByokCapabilityWarningType>;
|
||||
workspaceId: Scalars['String']['output'];
|
||||
}
|
||||
|
||||
export interface WorkspaceByokUsagePointType {
|
||||
__typename?: 'WorkspaceByokUsagePointType';
|
||||
date: Scalars['DateTime']['output'];
|
||||
featureKind: Scalars['String']['output'];
|
||||
totalTokens: Scalars['SafeInt']['output'];
|
||||
}
|
||||
|
||||
export interface WorkspaceCalendarItemInput {
|
||||
colorOverride?: InputMaybe<Scalars['String']['input']>;
|
||||
sortOrder?: InputMaybe<Scalars['Int']['input']>;
|
||||
@@ -3453,6 +3589,8 @@ export interface WorkspaceType {
|
||||
blobs: Array<ListedBlob>;
|
||||
/** Blobs size of workspace */
|
||||
blobsSize: Scalars['Int']['output'];
|
||||
byokSettings: WorkspaceByokSettingsType;
|
||||
byokUsage: Array<WorkspaceByokUsagePointType>;
|
||||
calendars: Array<WorkspaceCalendarObjectType>;
|
||||
/** Get comment changes of a doc */
|
||||
commentChanges: PaginatedCommentChangeObjectType;
|
||||
@@ -3526,6 +3664,11 @@ export interface WorkspaceTypeBlobUploadPartUrlArgs {
|
||||
uploadId: Scalars['String']['input'];
|
||||
}
|
||||
|
||||
export interface WorkspaceTypeByokUsageArgs {
|
||||
from: Scalars['DateTime']['input'];
|
||||
to: Scalars['DateTime']['input'];
|
||||
}
|
||||
|
||||
export interface WorkspaceTypeCommentChangesArgs {
|
||||
docId: Scalars['String']['input'];
|
||||
pagination: PaginationInput;
|
||||
@@ -4642,18 +4785,6 @@ export type UploadCommentAttachmentMutation = {
|
||||
uploadCommentAttachment: string;
|
||||
};
|
||||
|
||||
export type ApplyDocUpdatesMutationVariables = Exact<{
|
||||
workspaceId: Scalars['String']['input'];
|
||||
docId: Scalars['String']['input'];
|
||||
op: Scalars['String']['input'];
|
||||
updates: Scalars['String']['input'];
|
||||
}>;
|
||||
|
||||
export type ApplyDocUpdatesMutation = {
|
||||
__typename?: 'Mutation';
|
||||
applyDocUpdates: string;
|
||||
};
|
||||
|
||||
export type AddContextBlobMutationVariables = Exact<{
|
||||
options: AddContextBlobInput;
|
||||
}>;
|
||||
@@ -7478,6 +7609,136 @@ export type WorkspaceBlobQuotaQuery = {
|
||||
};
|
||||
};
|
||||
|
||||
export type ClearWorkspaceByokConfigsMutationVariables = Exact<{
|
||||
workspaceId: Scalars['String']['input'];
|
||||
}>;
|
||||
|
||||
export type ClearWorkspaceByokConfigsMutation = {
|
||||
__typename?: 'Mutation';
|
||||
clearWorkspaceByokConfigs: boolean;
|
||||
};
|
||||
|
||||
export type DeleteWorkspaceByokConfigMutationVariables = Exact<{
|
||||
workspaceId: Scalars['String']['input'];
|
||||
id: Scalars['ID']['input'];
|
||||
}>;
|
||||
|
||||
export type DeleteWorkspaceByokConfigMutation = {
|
||||
__typename?: 'Mutation';
|
||||
deleteWorkspaceByokConfig: boolean;
|
||||
};
|
||||
|
||||
export type ReorderWorkspaceByokConfigsMutationVariables = Exact<{
|
||||
input: ReorderWorkspaceByokConfigsInput;
|
||||
}>;
|
||||
|
||||
export type ReorderWorkspaceByokConfigsMutation = {
|
||||
__typename?: 'Mutation';
|
||||
reorderWorkspaceByokConfigs: Array<{
|
||||
__typename?: 'WorkspaceByokKeyConfigType';
|
||||
id: string;
|
||||
sortOrder: number;
|
||||
}>;
|
||||
};
|
||||
|
||||
export type TestWorkspaceByokConfigMutationVariables = Exact<{
|
||||
input: TestWorkspaceByokConfigInput;
|
||||
}>;
|
||||
|
||||
export type TestWorkspaceByokConfigMutation = {
|
||||
__typename?: 'Mutation';
|
||||
testWorkspaceByokConfig: {
|
||||
__typename?: 'TestWorkspaceByokConfigResultType';
|
||||
ok: boolean;
|
||||
status: ByokKeyTestStatus;
|
||||
message: string | null;
|
||||
};
|
||||
};
|
||||
|
||||
export type UpsertWorkspaceByokConfigMutationVariables = Exact<{
|
||||
input: UpsertWorkspaceByokConfigInput;
|
||||
}>;
|
||||
|
||||
export type UpsertWorkspaceByokConfigMutation = {
|
||||
__typename?: 'Mutation';
|
||||
upsertWorkspaceByokConfig: {
|
||||
__typename?: 'WorkspaceByokKeyConfigType';
|
||||
id: string;
|
||||
};
|
||||
};
|
||||
|
||||
export type CreateWorkspaceByokLocalLeaseMutationVariables = Exact<{
|
||||
input: CreateWorkspaceByokLocalLeaseInput;
|
||||
}>;
|
||||
|
||||
export type CreateWorkspaceByokLocalLeaseMutation = {
|
||||
__typename?: 'Mutation';
|
||||
createWorkspaceByokLocalLease: {
|
||||
__typename?: 'CreateWorkspaceByokLocalLeaseResultType';
|
||||
leaseId: string;
|
||||
expiresAt: string;
|
||||
};
|
||||
};
|
||||
|
||||
export type WorkspaceByokSettingsQueryVariables = Exact<{
|
||||
id: Scalars['String']['input'];
|
||||
from: Scalars['DateTime']['input'];
|
||||
to: Scalars['DateTime']['input'];
|
||||
}>;
|
||||
|
||||
export type WorkspaceByokSettingsQuery = {
|
||||
__typename?: 'Query';
|
||||
workspace: {
|
||||
__typename?: 'WorkspaceType';
|
||||
id: string;
|
||||
byokSettings: {
|
||||
__typename?: 'WorkspaceByokSettingsType';
|
||||
workspaceId: string;
|
||||
entitled: boolean;
|
||||
serverEntitled: boolean;
|
||||
localEntitled: boolean;
|
||||
entitlementRequired: Array<string>;
|
||||
allowedProviders: Array<ByokProvider>;
|
||||
localStorageSupported: boolean;
|
||||
customEndpointSupported: boolean;
|
||||
hasAiPlan: boolean;
|
||||
keys: Array<{
|
||||
__typename?: 'WorkspaceByokKeyConfigType';
|
||||
id: string;
|
||||
provider: ByokProvider;
|
||||
name: string;
|
||||
description: string | null;
|
||||
storage: ByokKeyStorage;
|
||||
configured: boolean;
|
||||
enabled: boolean;
|
||||
endpoint: string | null;
|
||||
endpointEditable: boolean;
|
||||
sortOrder: number;
|
||||
capabilities: Array<string>;
|
||||
testStatus: ByokKeyTestStatus;
|
||||
disabledReason: string | null;
|
||||
lastTestedAt: string | null;
|
||||
lastTestError: string | null;
|
||||
lastUsedAt: string | null;
|
||||
lastErrorAt: string | null;
|
||||
lastError: string | null;
|
||||
}>;
|
||||
warnings: Array<{
|
||||
__typename?: 'WorkspaceByokCapabilityWarningType';
|
||||
featureKind: string;
|
||||
reason: string;
|
||||
requiredProviders: Array<ByokProvider>;
|
||||
}>;
|
||||
};
|
||||
byokUsage: Array<{
|
||||
__typename?: 'WorkspaceByokUsagePointType';
|
||||
date: string;
|
||||
featureKind: string;
|
||||
totalTokens: number;
|
||||
}>;
|
||||
};
|
||||
};
|
||||
|
||||
export type GetWorkspaceConfigQueryVariables = Exact<{
|
||||
id: Scalars['String']['input'];
|
||||
}>;
|
||||
@@ -8104,6 +8365,11 @@ export type Queries =
|
||||
variables: WorkspaceBlobQuotaQueryVariables;
|
||||
response: WorkspaceBlobQuotaQuery;
|
||||
}
|
||||
| {
|
||||
name: 'workspaceByokSettingsQuery';
|
||||
variables: WorkspaceByokSettingsQueryVariables;
|
||||
response: WorkspaceByokSettingsQuery;
|
||||
}
|
||||
| {
|
||||
name: 'getWorkspaceConfigQuery';
|
||||
variables: GetWorkspaceConfigQueryVariables;
|
||||
@@ -8301,11 +8567,6 @@ export type Mutations =
|
||||
variables: UploadCommentAttachmentMutationVariables;
|
||||
response: UploadCommentAttachmentMutation;
|
||||
}
|
||||
| {
|
||||
name: 'applyDocUpdatesMutation';
|
||||
variables: ApplyDocUpdatesMutationVariables;
|
||||
response: ApplyDocUpdatesMutation;
|
||||
}
|
||||
| {
|
||||
name: 'addContextBlobMutation';
|
||||
variables: AddContextBlobMutationVariables;
|
||||
@@ -8606,6 +8867,36 @@ export type Mutations =
|
||||
variables: VerifyEmailMutationVariables;
|
||||
response: VerifyEmailMutation;
|
||||
}
|
||||
| {
|
||||
name: 'clearWorkspaceByokConfigsMutation';
|
||||
variables: ClearWorkspaceByokConfigsMutationVariables;
|
||||
response: ClearWorkspaceByokConfigsMutation;
|
||||
}
|
||||
| {
|
||||
name: 'deleteWorkspaceByokConfigMutation';
|
||||
variables: DeleteWorkspaceByokConfigMutationVariables;
|
||||
response: DeleteWorkspaceByokConfigMutation;
|
||||
}
|
||||
| {
|
||||
name: 'reorderWorkspaceByokConfigsMutation';
|
||||
variables: ReorderWorkspaceByokConfigsMutationVariables;
|
||||
response: ReorderWorkspaceByokConfigsMutation;
|
||||
}
|
||||
| {
|
||||
name: 'testWorkspaceByokConfigMutation';
|
||||
variables: TestWorkspaceByokConfigMutationVariables;
|
||||
response: TestWorkspaceByokConfigMutation;
|
||||
}
|
||||
| {
|
||||
name: 'upsertWorkspaceByokConfigMutation';
|
||||
variables: UpsertWorkspaceByokConfigMutationVariables;
|
||||
response: UpsertWorkspaceByokConfigMutation;
|
||||
}
|
||||
| {
|
||||
name: 'createWorkspaceByokLocalLeaseMutation';
|
||||
variables: CreateWorkspaceByokLocalLeaseMutationVariables;
|
||||
response: CreateWorkspaceByokLocalLeaseMutation;
|
||||
}
|
||||
| {
|
||||
name: 'setEnableAiMutation';
|
||||
variables: SetEnableAiMutationVariables;
|
||||
|
||||
@@ -313,6 +313,18 @@
|
||||
"type": "Boolean",
|
||||
"desc": "Whether to enable the copilot plugin. <br> Document: <a href=\"https://docs.affine.pro/self-host-affine/administer/ai\" target=\"_blank\">https://docs.affine.pro/self-host-affine/administer/ai</a>"
|
||||
},
|
||||
"byok.enabled": {
|
||||
"type": "Boolean",
|
||||
"desc": "Whether to enable workspace BYOK."
|
||||
},
|
||||
"byok.allowedProviders": {
|
||||
"type": "Array",
|
||||
"desc": "The allowlist for workspace BYOK providers."
|
||||
},
|
||||
"byok.allowCustomEndpoint": {
|
||||
"type": "Boolean",
|
||||
"desc": "Whether workspace BYOK custom endpoints are accepted."
|
||||
},
|
||||
"providers.profiles": {
|
||||
"type": "Array",
|
||||
"desc": "The profile list for copilot providers."
|
||||
@@ -342,10 +354,6 @@
|
||||
"type": "Object",
|
||||
"desc": "The config for the gemini provider in Google Vertex AI."
|
||||
},
|
||||
"providers.perplexity": {
|
||||
"type": "Object",
|
||||
"desc": "The config for the perplexity provider."
|
||||
},
|
||||
"providers.anthropic": {
|
||||
"type": "Object",
|
||||
"desc": "The config for the anthropic provider."
|
||||
@@ -354,10 +362,6 @@
|
||||
"type": "Object",
|
||||
"desc": "The config for the anthropic provider in Google Vertex AI."
|
||||
},
|
||||
"providers.morph": {
|
||||
"type": "Object",
|
||||
"desc": "The config for the morph provider."
|
||||
},
|
||||
"unsplash": {
|
||||
"type": "Object",
|
||||
"desc": "The config for the unsplash key."
|
||||
|
||||
@@ -153,7 +153,6 @@ export const KNOWN_CONFIG_GROUPS = [
|
||||
'scenarios',
|
||||
'providers.openai',
|
||||
'providers.gemini',
|
||||
'providers.perplexity',
|
||||
'providers.anthropic',
|
||||
'providers.fal',
|
||||
'unsplash',
|
||||
|
||||
@@ -0,0 +1,183 @@
|
||||
import path from 'node:path';
|
||||
|
||||
import { app, safeStorage } from 'electron';
|
||||
|
||||
import { PersistentJSONFileStorage } from '../shared-storage/json-file';
|
||||
import type { NamespaceHandlers } from '../type';
|
||||
|
||||
const byokStorage = new PersistentJSONFileStorage(
|
||||
path.join(app.getPath('userData'), 'workspace-byok-keys.json')
|
||||
);
|
||||
|
||||
export function disposeWorkspaceByokStorage() {
|
||||
byokStorage.dispose();
|
||||
}
|
||||
|
||||
const allowedProviders = new Set(['openai', 'anthropic', 'gemini', 'fal']);
|
||||
|
||||
type WorkspaceByokKey = {
|
||||
id: string;
|
||||
provider: 'openai' | 'anthropic' | 'gemini' | 'fal';
|
||||
name: string;
|
||||
description?: string | null;
|
||||
apiKey: string;
|
||||
endpoint?: string | null;
|
||||
sortOrder?: number | null;
|
||||
enabled?: boolean | null;
|
||||
};
|
||||
|
||||
type WorkspaceByokKeyInput = Omit<WorkspaceByokKey, 'apiKey'> & {
|
||||
apiKey?: string | null;
|
||||
};
|
||||
|
||||
function assertSupported() {
|
||||
if (!safeStorage.isEncryptionAvailable()) {
|
||||
throw new Error('Secure BYOK key storage is not available.');
|
||||
}
|
||||
}
|
||||
|
||||
function hasOwnField(
|
||||
key: WorkspaceByokKeyInput,
|
||||
field: keyof WorkspaceByokKey
|
||||
) {
|
||||
return Object.prototype.hasOwnProperty.call(key, field);
|
||||
}
|
||||
|
||||
function normalizeKey(
|
||||
key: WorkspaceByokKeyInput,
|
||||
existing?: WorkspaceByokKey,
|
||||
defaultSortOrder = 0
|
||||
): WorkspaceByokKey {
|
||||
if (!allowedProviders.has(key.provider)) {
|
||||
throw new Error('Unsupported BYOK provider.');
|
||||
}
|
||||
const apiKey = key.apiKey ?? existing?.apiKey;
|
||||
if (!key.id || !key.name || !apiKey) {
|
||||
throw new Error('Invalid BYOK key.');
|
||||
}
|
||||
return {
|
||||
id: key.id,
|
||||
provider: key.provider,
|
||||
name: key.name,
|
||||
description: hasOwnField(key, 'description')
|
||||
? (key.description ?? null)
|
||||
: (existing?.description ?? null),
|
||||
apiKey,
|
||||
endpoint: hasOwnField(key, 'endpoint')
|
||||
? (key.endpoint ?? null)
|
||||
: (existing?.endpoint ?? null),
|
||||
sortOrder: hasOwnField(key, 'sortOrder')
|
||||
? (key.sortOrder ?? defaultSortOrder)
|
||||
: (existing?.sortOrder ?? defaultSortOrder),
|
||||
enabled: hasOwnField(key, 'enabled')
|
||||
? (key.enabled ?? true)
|
||||
: (existing?.enabled ?? true),
|
||||
};
|
||||
}
|
||||
|
||||
function encryptKey(key: WorkspaceByokKey) {
|
||||
return safeStorage
|
||||
.encryptString(JSON.stringify(normalizeKey(key)))
|
||||
.toString('base64');
|
||||
}
|
||||
|
||||
function decryptKey(value: string): WorkspaceByokKey | null {
|
||||
try {
|
||||
return normalizeKey(
|
||||
JSON.parse(safeStorage.decryptString(Buffer.from(value, 'base64')))
|
||||
);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function sortWorkspaceKeys(keys: WorkspaceByokKey[]) {
|
||||
return keys.toSorted((a, b) => (a.sortOrder ?? 0) - (b.sortOrder ?? 0));
|
||||
}
|
||||
|
||||
function readWorkspaceKeys(workspaceId: string): WorkspaceByokKey[] {
|
||||
assertSupported();
|
||||
const encryptedKeys = byokStorage.get<string[]>(workspaceId) ?? [];
|
||||
return sortWorkspaceKeys(
|
||||
encryptedKeys.flatMap(value => {
|
||||
const key = decryptKey(value);
|
||||
return key ? [key] : [];
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
function writeWorkspaceKeys(workspaceId: string, keys: WorkspaceByokKey[]) {
|
||||
assertSupported();
|
||||
byokStorage.set(workspaceId, keys.map(encryptKey));
|
||||
}
|
||||
|
||||
function toPublicKey({ apiKey: _, ...key }: WorkspaceByokKey) {
|
||||
return {
|
||||
...key,
|
||||
storage: 'local',
|
||||
configured: true,
|
||||
endpointEditable: false,
|
||||
testStatus: 'passed',
|
||||
};
|
||||
}
|
||||
|
||||
export const byokStorageHandlers = {
|
||||
isSupported: async () => safeStorage.isEncryptionAvailable(),
|
||||
listWorkspaceKeys: async (_e, workspaceId: string) => {
|
||||
return readWorkspaceKeys(workspaceId).map(toPublicKey);
|
||||
},
|
||||
getWorkspaceLeaseProviders: async (_e, workspaceId: string) => {
|
||||
return readWorkspaceKeys(workspaceId).filter(key => key.enabled !== false);
|
||||
},
|
||||
upsertWorkspaceKey: async (
|
||||
_e,
|
||||
workspaceId: string,
|
||||
key: WorkspaceByokKeyInput
|
||||
) => {
|
||||
const keys = readWorkspaceKeys(workspaceId);
|
||||
const index = keys.findIndex(storedKey => storedKey.id === key.id);
|
||||
const nextKey = normalizeKey(
|
||||
key,
|
||||
index === -1 ? undefined : keys[index],
|
||||
keys.length
|
||||
);
|
||||
if (index === -1) {
|
||||
keys.push(nextKey);
|
||||
} else {
|
||||
keys[index] = nextKey;
|
||||
}
|
||||
writeWorkspaceKeys(workspaceId, keys);
|
||||
return toPublicKey(nextKey);
|
||||
},
|
||||
deleteWorkspaceKey: async (_e, workspaceId: string, keyId: string) => {
|
||||
writeWorkspaceKeys(
|
||||
workspaceId,
|
||||
readWorkspaceKeys(workspaceId).filter(key => key.id !== keyId)
|
||||
);
|
||||
return true;
|
||||
},
|
||||
reorderWorkspaceKeys: async (_e, workspaceId: string, ids: string[]) => {
|
||||
const keys = readWorkspaceKeys(workspaceId);
|
||||
const byId = new Map(keys.map(key => [key.id, key]));
|
||||
const ordered = ids
|
||||
.map((id, sortOrder) => {
|
||||
const key = byId.get(id);
|
||||
byId.delete(id);
|
||||
return key ? ({ ...key, sortOrder } as WorkspaceByokKey) : null;
|
||||
})
|
||||
.filter((key): key is WorkspaceByokKey => !!key);
|
||||
const nextKeys = sortWorkspaceKeys([
|
||||
...ordered,
|
||||
...Array.from(byId.values()).map((key, index) => ({
|
||||
...key,
|
||||
sortOrder: ordered.length + index,
|
||||
})),
|
||||
]);
|
||||
writeWorkspaceKeys(workspaceId, nextKeys);
|
||||
return nextKeys.map(toPublicKey);
|
||||
},
|
||||
clearWorkspaceKeys: async (_e, workspaceId: string) => {
|
||||
byokStorage.del(workspaceId);
|
||||
return true;
|
||||
},
|
||||
} satisfies NamespaceHandlers;
|
||||
@@ -2,6 +2,7 @@ import { I18n } from '@affine/i18n';
|
||||
import { ipcMain } from 'electron';
|
||||
|
||||
import { AFFINE_API_CHANNEL_NAME } from '../shared/type';
|
||||
import { byokStorageHandlers } from './byok-storage/handlers';
|
||||
import { clipboardHandlers } from './clipboard';
|
||||
import { configStorageHandlers } from './config-storage';
|
||||
import { findInPageHandlers } from './find-in-page';
|
||||
@@ -42,6 +43,7 @@ export const allHandlers = {
|
||||
recording: recordingHandlers,
|
||||
popup: popupHandlers,
|
||||
i18n: i18nHandlers,
|
||||
byokStorage: byokStorageHandlers,
|
||||
};
|
||||
|
||||
export const registerHandlers = () => {
|
||||
|
||||
151
packages/frontend/apps/electron/test/main/byok-storage.spec.ts
Normal file
151
packages/frontend/apps/electron/test/main/byok-storage.spec.ts
Normal file
@@ -0,0 +1,151 @@
|
||||
import path from 'node:path';
|
||||
|
||||
import fs from 'fs-extra';
|
||||
import { afterEach, beforeEach, describe, expect, test, vi } from 'vitest';
|
||||
|
||||
const tmpDir = path.join(__dirname, 'tmp-byok-storage');
|
||||
let disposeWorkspaceByokStorage: (() => void) | undefined;
|
||||
|
||||
vi.mock('electron', () => ({
|
||||
app: {
|
||||
getPath: () => tmpDir,
|
||||
on: vi.fn(),
|
||||
},
|
||||
safeStorage: {
|
||||
isEncryptionAvailable: () => true,
|
||||
encryptString: (value: string) => Buffer.from(value, 'utf-8'),
|
||||
decryptString: (value: Buffer) => value.toString('utf-8'),
|
||||
},
|
||||
}));
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.resetModules();
|
||||
disposeWorkspaceByokStorage = undefined;
|
||||
await fs.remove(tmpDir);
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
disposeWorkspaceByokStorage?.();
|
||||
vi.resetModules();
|
||||
await fs.remove(tmpDir);
|
||||
});
|
||||
|
||||
describe('byok storage handlers', () => {
|
||||
test('stores encrypted local keys and keeps lease providers sorted', async () => {
|
||||
const { byokStorageHandlers, disposeWorkspaceByokStorage: dispose } =
|
||||
await import('@affine/electron/main/byok-storage/handlers');
|
||||
disposeWorkspaceByokStorage = dispose;
|
||||
const ipcEvent = undefined;
|
||||
|
||||
await byokStorageHandlers.upsertWorkspaceKey(ipcEvent, 'workspace-1', {
|
||||
id: 'local-openai',
|
||||
provider: 'openai',
|
||||
name: 'OpenAI',
|
||||
apiKey: 'sk-openai',
|
||||
sortOrder: 1,
|
||||
});
|
||||
await byokStorageHandlers.upsertWorkspaceKey(ipcEvent, 'workspace-1', {
|
||||
id: 'local-gemini',
|
||||
provider: 'gemini',
|
||||
name: 'Gemini',
|
||||
apiKey: 'sk-gemini',
|
||||
sortOrder: 0,
|
||||
});
|
||||
|
||||
const list = await byokStorageHandlers.listWorkspaceKeys(
|
||||
ipcEvent,
|
||||
'workspace-1'
|
||||
);
|
||||
expect(list.map(key => key.id)).toEqual(['local-gemini', 'local-openai']);
|
||||
expect(JSON.stringify(list)).not.toContain('sk-openai');
|
||||
|
||||
const reordered = await byokStorageHandlers.reorderWorkspaceKeys(
|
||||
ipcEvent,
|
||||
'workspace-1',
|
||||
['local-openai', 'local-gemini']
|
||||
);
|
||||
expect(reordered.map(key => key.id)).toEqual([
|
||||
'local-openai',
|
||||
'local-gemini',
|
||||
]);
|
||||
|
||||
const leaseProviders = await byokStorageHandlers.getWorkspaceLeaseProviders(
|
||||
ipcEvent,
|
||||
'workspace-1'
|
||||
);
|
||||
expect(leaseProviders.map(key => key.apiKey)).toEqual([
|
||||
'sk-openai',
|
||||
'sk-gemini',
|
||||
]);
|
||||
|
||||
await byokStorageHandlers.clearWorkspaceKeys(ipcEvent, 'workspace-1');
|
||||
await expect(
|
||||
byokStorageHandlers.listWorkspaceKeys(ipcEvent, 'workspace-1')
|
||||
).resolves.toEqual([]);
|
||||
});
|
||||
|
||||
test('preserves existing local key fields during partial updates', async () => {
|
||||
const { byokStorageHandlers, disposeWorkspaceByokStorage: dispose } =
|
||||
await import('@affine/electron/main/byok-storage/handlers');
|
||||
disposeWorkspaceByokStorage = dispose;
|
||||
const ipcEvent = undefined;
|
||||
|
||||
await byokStorageHandlers.upsertWorkspaceKey(ipcEvent, 'workspace-1', {
|
||||
id: 'local-openai',
|
||||
provider: 'openai',
|
||||
name: 'OpenAI',
|
||||
description: 'Primary key',
|
||||
apiKey: 'sk-openai',
|
||||
endpoint: 'https://api.openai.example/v1',
|
||||
sortOrder: 4,
|
||||
enabled: false,
|
||||
});
|
||||
|
||||
await byokStorageHandlers.upsertWorkspaceKey(ipcEvent, 'workspace-1', {
|
||||
id: 'local-openai',
|
||||
provider: 'openai',
|
||||
name: 'OpenAI renamed',
|
||||
apiKey: 'sk-openai-next',
|
||||
});
|
||||
|
||||
const [publicKey] = await byokStorageHandlers.listWorkspaceKeys(
|
||||
ipcEvent,
|
||||
'workspace-1'
|
||||
);
|
||||
expect(publicKey).toMatchObject({
|
||||
id: 'local-openai',
|
||||
name: 'OpenAI renamed',
|
||||
description: 'Primary key',
|
||||
endpoint: 'https://api.openai.example/v1',
|
||||
sortOrder: 4,
|
||||
enabled: false,
|
||||
});
|
||||
|
||||
const [leaseProvider] =
|
||||
await byokStorageHandlers.getWorkspaceLeaseProviders(
|
||||
ipcEvent,
|
||||
'workspace-1'
|
||||
);
|
||||
expect(leaseProvider).toBeUndefined();
|
||||
|
||||
await byokStorageHandlers.upsertWorkspaceKey(ipcEvent, 'workspace-1', {
|
||||
id: 'local-openai',
|
||||
provider: 'openai',
|
||||
name: 'OpenAI renamed again',
|
||||
enabled: true,
|
||||
});
|
||||
|
||||
const [enabledLeaseProvider] =
|
||||
await byokStorageHandlers.getWorkspaceLeaseProviders(
|
||||
ipcEvent,
|
||||
'workspace-1'
|
||||
);
|
||||
expect(enabledLeaseProvider).toMatchObject({
|
||||
name: 'OpenAI renamed again',
|
||||
apiKey: 'sk-openai-next',
|
||||
endpoint: 'https://api.openai.example/v1',
|
||||
sortOrder: 4,
|
||||
enabled: true,
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -23,7 +23,8 @@ export default defineConfig({
|
||||
test: {
|
||||
setupFiles: [resolve(rootDir, './scripts/setup/global.ts')],
|
||||
include: ['./test/**/*.spec.ts'],
|
||||
testTimeout: 30000,
|
||||
testTimeout: 60000,
|
||||
hookTimeout: 30000,
|
||||
pool: 'forks',
|
||||
maxWorkers: 1,
|
||||
coverage: {
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
// @generated
|
||||
// This file was automatically generated and should not be edited.
|
||||
|
||||
@_exported import ApolloAPI
|
||||
|
||||
public class ApplyDocUpdatesMutation: GraphQLMutation {
|
||||
public static let operationName: String = "applyDocUpdates"
|
||||
public static let operationDocument: ApolloAPI.OperationDocument = .init(
|
||||
definition: .init(
|
||||
#"mutation applyDocUpdates($workspaceId: String!, $docId: String!, $op: String!, $updates: String!) { applyDocUpdates( workspaceId: $workspaceId docId: $docId op: $op updates: $updates ) }"#
|
||||
))
|
||||
|
||||
public var workspaceId: String
|
||||
public var docId: String
|
||||
public var op: String
|
||||
public var updates: String
|
||||
|
||||
public init(
|
||||
workspaceId: String,
|
||||
docId: String,
|
||||
op: String,
|
||||
updates: String
|
||||
) {
|
||||
self.workspaceId = workspaceId
|
||||
self.docId = docId
|
||||
self.op = op
|
||||
self.updates = updates
|
||||
}
|
||||
|
||||
public var __variables: Variables? { [
|
||||
"workspaceId": workspaceId,
|
||||
"docId": docId,
|
||||
"op": op,
|
||||
"updates": updates
|
||||
] }
|
||||
|
||||
public struct Data: AffineGraphQL.SelectionSet {
|
||||
public let __data: DataDict
|
||||
public init(_dataDict: DataDict) { __data = _dataDict }
|
||||
|
||||
public static var __parentType: any ApolloAPI.ParentType { AffineGraphQL.Objects.Mutation }
|
||||
public static var __selections: [ApolloAPI.Selection] { [
|
||||
.field("applyDocUpdates", String.self, arguments: [
|
||||
"workspaceId": .variable("workspaceId"),
|
||||
"docId": .variable("docId"),
|
||||
"op": .variable("op"),
|
||||
"updates": .variable("updates")
|
||||
]),
|
||||
] }
|
||||
public static var __fulfilledFragments: [any ApolloAPI.SelectionSet.Type] { [
|
||||
ApplyDocUpdatesMutation.Data.self
|
||||
] }
|
||||
|
||||
/// Apply updates to a doc using LLM and return the merged markdown.
|
||||
public var applyDocUpdates: String { __data["applyDocUpdates"] }
|
||||
}
|
||||
}
|
||||
@@ -354,12 +354,6 @@ declare global {
|
||||
files?: ContextMatchedFileChunk[];
|
||||
docs?: ContextMatchedDocChunk[];
|
||||
}>;
|
||||
applyDocUpdates: (
|
||||
workspaceId: string,
|
||||
docId: string,
|
||||
op: string,
|
||||
updates: string
|
||||
) => Promise<string>;
|
||||
addContextBlob: (options: {
|
||||
blobId: string;
|
||||
contextId: string;
|
||||
|
||||
@@ -2,7 +2,6 @@ import track from '@affine/track';
|
||||
import { WithDisposable } from '@blocksuite/affine/global/lit';
|
||||
import { unsafeCSSVar, unsafeCSSVarV2 } from '@blocksuite/affine/shared/theme';
|
||||
import { type EditorHost, ShadowlessElement } from '@blocksuite/affine/std';
|
||||
import { LoadingIcon } from '@blocksuite/affine-components/icons';
|
||||
import type { NotificationService } from '@blocksuite/affine-shared/services';
|
||||
import {
|
||||
CloseIcon,
|
||||
@@ -17,8 +16,6 @@ import { css, html, nothing } from 'lit';
|
||||
import { property, state } from 'lit/decorators.js';
|
||||
import { repeat } from 'lit/directives/repeat.js';
|
||||
|
||||
import { AIProvider } from '../../provider';
|
||||
import { BlockDiffProvider } from '../../services/block-diff';
|
||||
import { diffMarkdown } from '../../utils/apply-model/markdown-diff';
|
||||
import { copyText } from '../../utils/editor-actions';
|
||||
import { AI_CHAT_AUTO_SCROLL_PAUSE_EVENT } from '../ai-chat-messages/auto-scroll';
|
||||
@@ -218,61 +215,21 @@ export class DocEditTool extends WithDisposable(ShadowlessElement) {
|
||||
@state()
|
||||
accessor isCollapsed = false;
|
||||
|
||||
@state()
|
||||
accessor applyingMap: Record<string, boolean> = {};
|
||||
|
||||
@state()
|
||||
accessor acceptingMap: Record<string, boolean> = {};
|
||||
|
||||
get blockDiffService() {
|
||||
return this.host?.std.getOptional(BlockDiffProvider);
|
||||
}
|
||||
|
||||
get isBusy() {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
isBusyForOp(op: string) {
|
||||
return this.applyingMap[op] || this.acceptingMap[op];
|
||||
}
|
||||
|
||||
private async _handleApply(op: string, updates: string) {
|
||||
if (
|
||||
!this.host ||
|
||||
this.data.type !== 'tool-result' ||
|
||||
this.isBusyForOp(op)
|
||||
) {
|
||||
return;
|
||||
}
|
||||
this.applyingMap = { ...this.applyingMap, [op]: true };
|
||||
try {
|
||||
const markdown = await AIProvider.context?.applyDocUpdates(
|
||||
this.host.std.workspace.id,
|
||||
this.data.args.doc_id,
|
||||
op,
|
||||
updates
|
||||
);
|
||||
if (!markdown) {
|
||||
private _handleApply(op: string) {
|
||||
if (!this.host || this.data.type !== 'tool-result') {
|
||||
return;
|
||||
}
|
||||
track.applyModel.chat.$.apply({
|
||||
instruction: this.data.args.instructions,
|
||||
operation: op,
|
||||
});
|
||||
await this.blockDiffService?.apply(this.host.store, markdown);
|
||||
} catch (error) {
|
||||
this.notificationService.notify({
|
||||
title: 'Failed to apply updates',
|
||||
message: error instanceof Error ? error.message : 'Unknown error',
|
||||
accent: 'error',
|
||||
onClose: function (): void {},
|
||||
});
|
||||
} finally {
|
||||
this.applyingMap = { ...this.applyingMap, [op]: false };
|
||||
}
|
||||
}
|
||||
|
||||
private async _handleReject(op: string) {
|
||||
private _handleReject(op: string) {
|
||||
if (!this.host || this.data.type !== 'tool-result') {
|
||||
return;
|
||||
}
|
||||
@@ -281,45 +238,16 @@ export class DocEditTool extends WithDisposable(ShadowlessElement) {
|
||||
instruction: this.data.args.instructions,
|
||||
operation: op,
|
||||
});
|
||||
this.blockDiffService?.setChangedMarkdown(null);
|
||||
this.blockDiffService?.rejectAll();
|
||||
}
|
||||
|
||||
private async _handleAccept(op: string, updates: string) {
|
||||
if (
|
||||
!this.host ||
|
||||
this.data.type !== 'tool-result' ||
|
||||
this.isBusyForOp(op)
|
||||
) {
|
||||
return;
|
||||
}
|
||||
this.acceptingMap = { ...this.acceptingMap, [op]: true };
|
||||
try {
|
||||
const changedMarkdown = await AIProvider.context?.applyDocUpdates(
|
||||
this.host.std.workspace.id,
|
||||
this.data.args.doc_id,
|
||||
op,
|
||||
updates
|
||||
);
|
||||
if (!changedMarkdown) {
|
||||
private _handleAccept(op: string) {
|
||||
if (!this.host || this.data.type !== 'tool-result') {
|
||||
return;
|
||||
}
|
||||
track.applyModel.chat.$.accept({
|
||||
instruction: this.data.args.instructions,
|
||||
operation: op,
|
||||
});
|
||||
await this.blockDiffService?.apply(this.host.store, changedMarkdown);
|
||||
await this.blockDiffService?.acceptAll(this.host.store);
|
||||
} catch (error) {
|
||||
this.notificationService.notify({
|
||||
title: 'Failed to apply updates',
|
||||
message: error instanceof Error ? error.message : 'Unknown error',
|
||||
accent: 'error',
|
||||
onClose: function (): void {},
|
||||
});
|
||||
} finally {
|
||||
this.acceptingMap = { ...this.acceptingMap, [op]: false };
|
||||
}
|
||||
}
|
||||
|
||||
private async _toggleCollapse() {
|
||||
@@ -421,7 +349,7 @@ export class DocEditTool extends WithDisposable(ShadowlessElement) {
|
||||
return repeat(
|
||||
result.result,
|
||||
change => change.op,
|
||||
({ op, updates, originalContent, changedContent }) => {
|
||||
({ op, originalContent, changedContent }) => {
|
||||
const diffs = diffMarkdown(originalContent, changedContent);
|
||||
return html`
|
||||
<div class="doc-edit-tool-result-wrapper">
|
||||
@@ -449,14 +377,7 @@ export class DocEditTool extends WithDisposable(ShadowlessElement) {
|
||||
${CopyIcon()}
|
||||
<affine-tooltip>Copy</affine-tooltip>
|
||||
</button>
|
||||
<button
|
||||
@click=${() => this._handleApply(op, updates)}
|
||||
?disabled=${this.isBusyForOp(op)}
|
||||
>
|
||||
${this.applyingMap[op]
|
||||
? html`${LoadingIcon()} Applying`
|
||||
: 'Apply'}
|
||||
</button>
|
||||
<button @click=${() => this._handleApply(op)}>Apply</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="doc-edit-tool-result-card-content">
|
||||
@@ -473,18 +394,12 @@ export class DocEditTool extends WithDisposable(ShadowlessElement) {
|
||||
</button>
|
||||
<button
|
||||
class="doc-edit-tool-result-accept"
|
||||
@click=${() => this._handleAccept(op, updates)}
|
||||
?disabled=${this.isBusyForOp(op)}
|
||||
style="${this.isBusyForOp(op)
|
||||
? 'pointer-events: none; opacity: 0.6;'
|
||||
: ''}"
|
||||
@click=${() => this._handleAccept(op)}
|
||||
>
|
||||
${this.acceptingMap[op]
|
||||
? html`${LoadingIcon()}`
|
||||
: DoneIcon({
|
||||
${DoneIcon({
|
||||
style: `color: ${unsafeCSSVarV2('icon/activated')}`,
|
||||
})}
|
||||
${this.acceptingMap[op] ? 'Accepting...' : 'Accept'}
|
||||
Accept
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -6,7 +6,6 @@ import {
|
||||
addContextCategoryMutation,
|
||||
addContextDocMutation,
|
||||
addContextFileMutation,
|
||||
applyDocUpdatesMutation,
|
||||
cleanupCopilotSessionMutation,
|
||||
createCopilotContextMutation,
|
||||
createCopilotMessageMutation,
|
||||
@@ -473,6 +472,7 @@ export class CopilotClient {
|
||||
actionVersion,
|
||||
runId,
|
||||
retry,
|
||||
byokLeaseId,
|
||||
}: {
|
||||
sessionId: string;
|
||||
messageId?: string;
|
||||
@@ -483,6 +483,7 @@ export class CopilotClient {
|
||||
actionVersion?: string;
|
||||
runId?: string;
|
||||
retry?: boolean;
|
||||
byokLeaseId?: string;
|
||||
},
|
||||
endpoint = Endpoint.StreamObject
|
||||
) {
|
||||
@@ -499,6 +500,7 @@ export class CopilotClient {
|
||||
actionVersion,
|
||||
runId,
|
||||
retry,
|
||||
byokLeaseId,
|
||||
});
|
||||
if (queryString) {
|
||||
url += `?${queryString}`;
|
||||
@@ -511,12 +513,14 @@ export class CopilotClient {
|
||||
sessionId: string,
|
||||
messageId?: string,
|
||||
seed?: string,
|
||||
endpoint = Endpoint.Images
|
||||
endpoint = Endpoint.Images,
|
||||
byokLeaseId?: string
|
||||
) {
|
||||
let url = `/api/copilot/chat/${sessionId}/${endpoint}`;
|
||||
const queryString = this.paramsToQueryString({
|
||||
messageId,
|
||||
seed,
|
||||
byokLeaseId,
|
||||
});
|
||||
if (queryString) {
|
||||
url += `?${queryString}`;
|
||||
@@ -549,23 +553,6 @@ export class CopilotClient {
|
||||
}).then(res => res.queryWorkspaceEmbeddingStatus);
|
||||
}
|
||||
|
||||
applyDocUpdates(
|
||||
workspaceId: string,
|
||||
docId: string,
|
||||
op: string,
|
||||
updates: string
|
||||
) {
|
||||
return this.gql({
|
||||
query: applyDocUpdatesMutation,
|
||||
variables: {
|
||||
workspaceId,
|
||||
docId,
|
||||
op,
|
||||
updates,
|
||||
},
|
||||
}).then(res => res.applyDocUpdates);
|
||||
}
|
||||
|
||||
addContextBlob(options: OptionsField<typeof addContextBlobMutation>) {
|
||||
return this.gql({
|
||||
query: addContextBlobMutation,
|
||||
|
||||
@@ -0,0 +1,284 @@
|
||||
/**
|
||||
* @vitest-environment happy-dom
|
||||
*/
|
||||
import { UserFriendlyError } from '@affine/error';
|
||||
import { beforeEach, describe, expect, test, vi } from 'vitest';
|
||||
|
||||
import { type CopilotClient, Endpoint } from './copilot-client';
|
||||
import { textToText, toImage } from './request';
|
||||
|
||||
const electronApis = vi.hoisted(() => ({
|
||||
byokStorage: undefined as
|
||||
| {
|
||||
isSupported: () => Promise<boolean>;
|
||||
getWorkspaceLeaseProviders: (workspaceId: string) => Promise<
|
||||
Array<{
|
||||
provider: string;
|
||||
name: string;
|
||||
apiKey: string;
|
||||
description?: string | null;
|
||||
endpoint?: string | null;
|
||||
sortOrder?: number | null;
|
||||
enabled?: boolean | null;
|
||||
}>
|
||||
>;
|
||||
}
|
||||
| undefined,
|
||||
}));
|
||||
|
||||
const createWorkspaceByokLocalLeaseMutation = vi.hoisted(() =>
|
||||
Symbol('createWorkspaceByokLocalLeaseMutation')
|
||||
);
|
||||
|
||||
vi.mock('@affine/electron-api', () => ({
|
||||
apis: electronApis,
|
||||
}));
|
||||
|
||||
vi.mock('@affine/graphql', () => ({
|
||||
ByokProvider: {
|
||||
openai: 'openai',
|
||||
anthropic: 'anthropic',
|
||||
gemini: 'gemini',
|
||||
fal: 'fal',
|
||||
},
|
||||
createWorkspaceByokLocalLeaseMutation,
|
||||
}));
|
||||
|
||||
function createClient(
|
||||
overrides: Partial<
|
||||
Pick<
|
||||
CopilotClient,
|
||||
'gql' | 'createMessage' | 'chatTextStream' | 'imagesStream'
|
||||
>
|
||||
> = {}
|
||||
) {
|
||||
return {
|
||||
gql: vi.fn().mockResolvedValue({
|
||||
createWorkspaceByokLocalLease: { leaseId: 'lease-1' },
|
||||
}),
|
||||
createMessage: vi.fn().mockResolvedValue('message-1'),
|
||||
chatTextStream: vi.fn(),
|
||||
imagesStream: vi.fn(),
|
||||
...overrides,
|
||||
} as unknown as CopilotClient;
|
||||
}
|
||||
|
||||
async function drain(stream: AsyncIterable<unknown>) {
|
||||
for await (const chunk of stream) {
|
||||
void chunk;
|
||||
}
|
||||
}
|
||||
|
||||
describe('AI request BYOK local lease handling', () => {
|
||||
beforeEach(() => {
|
||||
vi.stubGlobal('BUILD_CONFIG', { isElectron: true });
|
||||
electronApis.byokStorage = {
|
||||
isSupported: vi.fn().mockResolvedValue(true),
|
||||
getWorkspaceLeaseProviders: vi.fn().mockResolvedValue([
|
||||
{
|
||||
provider: 'openai',
|
||||
name: 'OpenAI',
|
||||
apiKey: 'sk-local',
|
||||
},
|
||||
]),
|
||||
};
|
||||
});
|
||||
|
||||
test('fails closed when local BYOK providers exist but lease creation fails', async () => {
|
||||
const client = createClient({
|
||||
gql: vi.fn().mockRejectedValue(new Error('mutation failed')),
|
||||
});
|
||||
|
||||
const result = textToText({
|
||||
client,
|
||||
sessionId: 'session-1',
|
||||
workspaceId: 'workspace-1',
|
||||
content: 'hello',
|
||||
}) as Promise<string>;
|
||||
|
||||
await expect(result).rejects.toThrow('mutation failed');
|
||||
await expect(result).rejects.toBeInstanceOf(UserFriendlyError);
|
||||
expect(client.chatTextStream).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('wraps local BYOK storage support failures as user friendly errors', async () => {
|
||||
electronApis.byokStorage = {
|
||||
isSupported: vi.fn().mockRejectedValue(new Error('support check failed')),
|
||||
getWorkspaceLeaseProviders: vi.fn(),
|
||||
};
|
||||
const client = createClient();
|
||||
|
||||
const result = textToText({
|
||||
client,
|
||||
sessionId: 'session-1',
|
||||
workspaceId: 'workspace-1',
|
||||
content: 'hello',
|
||||
}) as Promise<string>;
|
||||
|
||||
await expect(result).rejects.toThrow('support check failed');
|
||||
await expect(result).rejects.toBeInstanceOf(UserFriendlyError);
|
||||
expect(client.chatTextStream).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('wraps local BYOK provider loading failures as user friendly errors', async () => {
|
||||
electronApis.byokStorage = {
|
||||
isSupported: vi.fn().mockResolvedValue(true),
|
||||
getWorkspaceLeaseProviders: vi
|
||||
.fn()
|
||||
.mockRejectedValue(new Error('provider load failed')),
|
||||
};
|
||||
const client = createClient();
|
||||
|
||||
const result = textToText({
|
||||
client,
|
||||
sessionId: 'session-1',
|
||||
workspaceId: 'workspace-1',
|
||||
content: 'hello',
|
||||
}) as Promise<string>;
|
||||
|
||||
await expect(result).rejects.toThrow('provider load failed');
|
||||
await expect(result).rejects.toBeInstanceOf(UserFriendlyError);
|
||||
expect(client.chatTextStream).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('does not create local BYOK lease after cancellation', async () => {
|
||||
const controller = new AbortController();
|
||||
const client = createClient({
|
||||
createMessage: vi.fn().mockImplementation(async () => {
|
||||
controller.abort();
|
||||
return 'message-1';
|
||||
}),
|
||||
});
|
||||
|
||||
await expect(
|
||||
textToText({
|
||||
client,
|
||||
sessionId: 'session-1',
|
||||
workspaceId: 'workspace-1',
|
||||
content: 'hello',
|
||||
signal: controller.signal,
|
||||
}) as Promise<string>
|
||||
).resolves.toBe('');
|
||||
expect(client.gql).not.toHaveBeenCalled();
|
||||
expect(client.chatTextStream).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('does not create stream local BYOK lease after cancellation', async () => {
|
||||
const controller = new AbortController();
|
||||
const client = createClient({
|
||||
createMessage: vi.fn().mockImplementation(async () => {
|
||||
controller.abort();
|
||||
return 'message-1';
|
||||
}),
|
||||
});
|
||||
|
||||
await drain(
|
||||
textToText({
|
||||
client,
|
||||
sessionId: 'session-1',
|
||||
workspaceId: 'workspace-1',
|
||||
content: 'hello',
|
||||
stream: true,
|
||||
signal: controller.signal,
|
||||
}) as AsyncIterable<string>
|
||||
);
|
||||
|
||||
expect(client.gql).not.toHaveBeenCalled();
|
||||
expect(client.chatTextStream).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('does not create text stream when cancelled while creating local BYOK lease', async () => {
|
||||
const controller = new AbortController();
|
||||
const client = createClient({
|
||||
gql: vi.fn().mockImplementation(async () => {
|
||||
controller.abort();
|
||||
return { createWorkspaceByokLocalLease: { leaseId: 'lease-1' } };
|
||||
}),
|
||||
});
|
||||
|
||||
await drain(
|
||||
textToText({
|
||||
client,
|
||||
sessionId: 'session-1',
|
||||
workspaceId: 'workspace-1',
|
||||
content: 'hello',
|
||||
stream: true,
|
||||
signal: controller.signal,
|
||||
}) as AsyncIterable<string>
|
||||
);
|
||||
|
||||
expect(client.gql).toHaveBeenCalled();
|
||||
expect(client.chatTextStream).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('does not create text request when cancelled while creating local BYOK lease', async () => {
|
||||
const controller = new AbortController();
|
||||
const client = createClient({
|
||||
gql: vi.fn().mockImplementation(async () => {
|
||||
controller.abort();
|
||||
return { createWorkspaceByokLocalLease: { leaseId: 'lease-1' } };
|
||||
}),
|
||||
});
|
||||
|
||||
await expect(
|
||||
textToText({
|
||||
client,
|
||||
sessionId: 'session-1',
|
||||
workspaceId: 'workspace-1',
|
||||
content: 'hello',
|
||||
signal: controller.signal,
|
||||
}) as Promise<string>
|
||||
).resolves.toBe('');
|
||||
|
||||
expect(client.gql).toHaveBeenCalled();
|
||||
expect(client.chatTextStream).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('does not create image local BYOK lease after cancellation', async () => {
|
||||
const controller = new AbortController();
|
||||
const client = createClient({
|
||||
createMessage: vi.fn().mockImplementation(async () => {
|
||||
controller.abort();
|
||||
return 'message-1';
|
||||
}),
|
||||
});
|
||||
|
||||
await drain(
|
||||
toImage({
|
||||
client,
|
||||
sessionId: 'session-1',
|
||||
workspaceId: 'workspace-1',
|
||||
content: 'image',
|
||||
endpoint: Endpoint.Images,
|
||||
signal: controller.signal,
|
||||
}) as AsyncIterable<string>
|
||||
);
|
||||
|
||||
expect(client.gql).not.toHaveBeenCalled();
|
||||
expect(client.imagesStream).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('does not create image stream when cancelled while creating local BYOK lease', async () => {
|
||||
const controller = new AbortController();
|
||||
const client = createClient({
|
||||
gql: vi.fn().mockImplementation(async () => {
|
||||
controller.abort();
|
||||
return { createWorkspaceByokLocalLease: { leaseId: 'lease-1' } };
|
||||
}),
|
||||
});
|
||||
|
||||
await drain(
|
||||
toImage({
|
||||
client,
|
||||
sessionId: 'session-1',
|
||||
workspaceId: 'workspace-1',
|
||||
content: 'image',
|
||||
endpoint: Endpoint.Images,
|
||||
signal: controller.signal,
|
||||
}) as AsyncIterable<string>
|
||||
);
|
||||
|
||||
expect(client.gql).toHaveBeenCalled();
|
||||
expect(client.imagesStream).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
@@ -1,4 +1,10 @@
|
||||
import type { AIToolsConfig } from '@affine/core/modules/ai-button';
|
||||
import { apis, type ClientHandler } from '@affine/electron-api';
|
||||
import { UserFriendlyError } from '@affine/error';
|
||||
import {
|
||||
ByokProvider,
|
||||
createWorkspaceByokLocalLeaseMutation,
|
||||
} from '@affine/graphql';
|
||||
import { partition } from 'lodash-es';
|
||||
|
||||
import { AIProvider } from './ai-provider';
|
||||
@@ -7,9 +13,99 @@ import { toTextStream } from './event-source';
|
||||
|
||||
const TIMEOUT = 50000;
|
||||
|
||||
function isElectronBuild() {
|
||||
return typeof BUILD_CONFIG !== 'undefined' && BUILD_CONFIG.isElectron;
|
||||
}
|
||||
|
||||
function byokStorageApi(): ClientHandler['byokStorage'] | undefined {
|
||||
return isElectronBuild() ? apis?.byokStorage : undefined;
|
||||
}
|
||||
|
||||
function toGraphqlByokProvider(provider: string): ByokProvider | null {
|
||||
switch (provider) {
|
||||
case ByokProvider.openai:
|
||||
return ByokProvider.openai;
|
||||
case ByokProvider.anthropic:
|
||||
return ByokProvider.anthropic;
|
||||
case ByokProvider.gemini:
|
||||
return ByokProvider.gemini;
|
||||
case ByokProvider.fal:
|
||||
return ByokProvider.fal;
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function errorMetadata(error: unknown) {
|
||||
if (!error || typeof error !== 'object') {
|
||||
return { kind: typeof error };
|
||||
}
|
||||
const record = error as Record<string, unknown>;
|
||||
return {
|
||||
name: typeof record.name === 'string' ? record.name : undefined,
|
||||
code: typeof record.code === 'string' ? record.code : undefined,
|
||||
status:
|
||||
typeof record.status === 'number' || typeof record.status === 'string'
|
||||
? record.status
|
||||
: undefined,
|
||||
type: typeof record.type === 'string' ? record.type : undefined,
|
||||
};
|
||||
}
|
||||
|
||||
async function createWorkspaceByokLocalLease(
|
||||
client: CopilotClient,
|
||||
workspaceId?: string
|
||||
) {
|
||||
const storage = byokStorageApi();
|
||||
if (!workspaceId || !storage) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
try {
|
||||
if (!(await storage.isSupported())) return undefined;
|
||||
const providers = await storage.getWorkspaceLeaseProviders(workspaceId);
|
||||
if (!providers.length) return undefined;
|
||||
const leaseProviders = providers.flatMap(provider => {
|
||||
const gqlProvider = toGraphqlByokProvider(provider.provider);
|
||||
return gqlProvider
|
||||
? [
|
||||
{
|
||||
provider: gqlProvider,
|
||||
name: provider.name,
|
||||
description: provider.description ?? null,
|
||||
apiKey: provider.apiKey,
|
||||
endpoint: provider.endpoint ?? null,
|
||||
sortOrder: provider.sortOrder ?? 0,
|
||||
enabled: provider.enabled ?? true,
|
||||
},
|
||||
]
|
||||
: [];
|
||||
});
|
||||
if (!leaseProviders.length) return undefined;
|
||||
|
||||
const result = await client.gql({
|
||||
query: createWorkspaceByokLocalLeaseMutation,
|
||||
variables: {
|
||||
input: {
|
||||
workspaceId,
|
||||
providers: leaseProviders,
|
||||
},
|
||||
},
|
||||
});
|
||||
return result.createWorkspaceByokLocalLease.leaseId;
|
||||
} catch (error) {
|
||||
console.warn(
|
||||
'Failed to create workspace BYOK local lease',
|
||||
errorMetadata(error)
|
||||
);
|
||||
throw UserFriendlyError.fromAny(error);
|
||||
}
|
||||
}
|
||||
|
||||
export type TextToTextOptions = {
|
||||
client: CopilotClient;
|
||||
sessionId: string;
|
||||
workspaceId?: string;
|
||||
content?: string;
|
||||
attachments?: (string | Blob | File)[];
|
||||
params?: Record<string, any>;
|
||||
@@ -114,6 +210,7 @@ async function createMessage({
|
||||
export function textToText({
|
||||
client,
|
||||
sessionId,
|
||||
workspaceId,
|
||||
content,
|
||||
attachments,
|
||||
params,
|
||||
@@ -145,6 +242,16 @@ export function textToText({
|
||||
signal,
|
||||
});
|
||||
}
|
||||
if (signal?.aborted) {
|
||||
return;
|
||||
}
|
||||
const byokLeaseId = await createWorkspaceByokLocalLease(
|
||||
client,
|
||||
workspaceId
|
||||
);
|
||||
if (signal?.aborted) {
|
||||
return;
|
||||
}
|
||||
const eventSource = client.chatTextStream(
|
||||
{
|
||||
sessionId,
|
||||
@@ -156,6 +263,7 @@ export function textToText({
|
||||
actionVersion,
|
||||
runId,
|
||||
retry,
|
||||
byokLeaseId,
|
||||
},
|
||||
endpoint
|
||||
);
|
||||
@@ -203,6 +311,16 @@ export function textToText({
|
||||
signal,
|
||||
});
|
||||
}
|
||||
if (signal?.aborted) {
|
||||
return '';
|
||||
}
|
||||
const byokLeaseId = await createWorkspaceByokLocalLease(
|
||||
client,
|
||||
workspaceId
|
||||
);
|
||||
if (signal?.aborted) {
|
||||
return '';
|
||||
}
|
||||
const eventSource = client.chatTextStream(
|
||||
{
|
||||
sessionId,
|
||||
@@ -214,6 +332,7 @@ export function textToText({
|
||||
actionVersion,
|
||||
runId,
|
||||
retry,
|
||||
byokLeaseId,
|
||||
},
|
||||
endpoint
|
||||
);
|
||||
@@ -258,6 +377,7 @@ export function textToText({
|
||||
export function toImage({
|
||||
content,
|
||||
sessionId,
|
||||
workspaceId,
|
||||
attachments,
|
||||
params,
|
||||
seed,
|
||||
@@ -284,6 +404,16 @@ export function toImage({
|
||||
signal,
|
||||
});
|
||||
}
|
||||
if (signal?.aborted) {
|
||||
return;
|
||||
}
|
||||
const byokLeaseId = await createWorkspaceByokLocalLease(
|
||||
client,
|
||||
workspaceId
|
||||
);
|
||||
if (signal?.aborted) {
|
||||
return;
|
||||
}
|
||||
const eventSource =
|
||||
endpoint === Endpoint.Action
|
||||
? client.chatTextStream(
|
||||
@@ -294,10 +424,17 @@ export function toImage({
|
||||
actionVersion,
|
||||
runId,
|
||||
retry,
|
||||
byokLeaseId,
|
||||
},
|
||||
Endpoint.Action
|
||||
)
|
||||
: client.imagesStream(sessionId, messageId, seed, endpoint);
|
||||
: client.imagesStream(
|
||||
sessionId,
|
||||
messageId,
|
||||
seed,
|
||||
endpoint,
|
||||
byokLeaseId
|
||||
);
|
||||
AIProvider.LAST_ACTION_SESSIONID = sessionId;
|
||||
|
||||
for await (const event of toTextStream(eventSource, {
|
||||
|
||||
@@ -722,14 +722,6 @@ Could you make a new website based on these notes and send back just the html fi
|
||||
threshold
|
||||
);
|
||||
},
|
||||
applyDocUpdates: async (
|
||||
workspaceId: string,
|
||||
docId: string,
|
||||
op: string,
|
||||
updates: string
|
||||
) => {
|
||||
return client.applyDocUpdates(workspaceId, docId, op, updates);
|
||||
},
|
||||
addContextBlob: async (options: { blobId: string; contextId: string }) => {
|
||||
return client.addContextBlob({
|
||||
contextId: options.contextId,
|
||||
|
||||
@@ -0,0 +1,346 @@
|
||||
import { Button, Modal, notify } from '@affine/component';
|
||||
import {
|
||||
ByokKeyStorage,
|
||||
ByokProvider,
|
||||
testWorkspaceByokConfigMutation as testByokMutation,
|
||||
upsertWorkspaceByokConfigMutation as upsertByokMutation,
|
||||
} from '@affine/graphql';
|
||||
import { useI18n } from '@affine/i18n';
|
||||
import { useCallback, useEffect, useState } from 'react';
|
||||
|
||||
import { logByokError } from './errors';
|
||||
import * as styles from './index.css';
|
||||
import { readLocalKeys, upsertLocalKey } from './local-storage';
|
||||
import { byokT, providerLabels, storageLabel } from './metadata';
|
||||
import type {
|
||||
ByokKey,
|
||||
ByokSettings,
|
||||
ByokStorage,
|
||||
ByokTestResult,
|
||||
GqlFn,
|
||||
} from './types';
|
||||
|
||||
export const AddKeyModal = ({
|
||||
workspaceId,
|
||||
settings,
|
||||
editingKey,
|
||||
open,
|
||||
onOpenChange,
|
||||
onSaved,
|
||||
localKeys,
|
||||
setLocalKeys,
|
||||
localStorageSupported,
|
||||
canAddServerKey,
|
||||
canAddLocalKey,
|
||||
gql,
|
||||
}: {
|
||||
workspaceId: string;
|
||||
settings: ByokSettings;
|
||||
editingKey: ByokKey | null;
|
||||
open: boolean;
|
||||
onOpenChange: (open: boolean) => void;
|
||||
onSaved: () => Promise<void>;
|
||||
localKeys: ByokKey[];
|
||||
setLocalKeys: (keys: ByokKey[]) => void;
|
||||
localStorageSupported: boolean;
|
||||
canAddServerKey: boolean;
|
||||
canAddLocalKey: boolean;
|
||||
gql?: GqlFn;
|
||||
}) => {
|
||||
const t = useI18n();
|
||||
const [provider, setProvider] = useState<ByokProvider>(ByokProvider.openai);
|
||||
const [name, setName] = useState('');
|
||||
const [description, setDescription] = useState('');
|
||||
const [storage, setStorage] = useState<ByokStorage>(ByokKeyStorage.server);
|
||||
const [apiKey, setApiKey] = useState('');
|
||||
const [endpoint, setEndpoint] = useState('');
|
||||
const [testResult, setTestResult] = useState<ByokTestResult | null>(null);
|
||||
const [testing, setTesting] = useState(false);
|
||||
const canTestStoredConfig =
|
||||
storage === ByokKeyStorage.server &&
|
||||
editingKey?.storage === ByokKeyStorage.server &&
|
||||
editingKey.provider === provider;
|
||||
const canTest = !!apiKey || canTestStoredConfig;
|
||||
|
||||
useEffect(() => {
|
||||
if (!open) {
|
||||
return;
|
||||
}
|
||||
setProvider(editingKey?.provider ?? ByokProvider.openai);
|
||||
setName(editingKey?.name ?? '');
|
||||
setDescription(editingKey?.description ?? '');
|
||||
setStorage(
|
||||
editingKey?.storage ??
|
||||
(canAddServerKey ? ByokKeyStorage.server : ByokKeyStorage.local)
|
||||
);
|
||||
setApiKey('');
|
||||
setEndpoint(editingKey?.endpoint ?? '');
|
||||
setTestResult(null);
|
||||
}, [canAddServerKey, editingKey, open]);
|
||||
|
||||
const testKey = useCallback(async () => {
|
||||
if (!gql) {
|
||||
return;
|
||||
}
|
||||
setTesting(true);
|
||||
try {
|
||||
const result = await gql({
|
||||
query: testByokMutation,
|
||||
variables: {
|
||||
input: {
|
||||
workspaceId,
|
||||
provider,
|
||||
storage,
|
||||
apiKey: apiKey || null,
|
||||
endpoint: endpoint || null,
|
||||
configId: canTestStoredConfig ? editingKey.id : null,
|
||||
},
|
||||
},
|
||||
});
|
||||
const nextResult = result.testWorkspaceByokConfig as
|
||||
| ByokTestResult
|
||||
| undefined;
|
||||
setTestResult(nextResult ?? null);
|
||||
if (nextResult && !nextResult.ok) {
|
||||
notify.error({
|
||||
title: byokT(t, 'notify.test-failed.title'),
|
||||
message: nextResult.message,
|
||||
});
|
||||
}
|
||||
} finally {
|
||||
setTesting(false);
|
||||
}
|
||||
}, [
|
||||
apiKey,
|
||||
canTestStoredConfig,
|
||||
editingKey,
|
||||
endpoint,
|
||||
gql,
|
||||
provider,
|
||||
storage,
|
||||
t,
|
||||
workspaceId,
|
||||
]);
|
||||
|
||||
const save = useCallback(async () => {
|
||||
if (!testResult?.ok || !gql) {
|
||||
return;
|
||||
}
|
||||
if (storage === ByokKeyStorage.local) {
|
||||
const saved = await upsertLocalKey(workspaceId, {
|
||||
id:
|
||||
editingKey?.storage === ByokKeyStorage.local
|
||||
? editingKey.id
|
||||
: crypto.randomUUID(),
|
||||
provider,
|
||||
name,
|
||||
description,
|
||||
apiKey,
|
||||
endpoint: endpoint || null,
|
||||
sortOrder:
|
||||
editingKey?.storage === ByokKeyStorage.local
|
||||
? editingKey.sortOrder
|
||||
: localKeys.length,
|
||||
enabled: true,
|
||||
});
|
||||
if (!saved) {
|
||||
notify.error({
|
||||
title: byokT(t, 'notify.local-save-failed.title'),
|
||||
message: byokT(t, 'notify.local-save-failed.message'),
|
||||
});
|
||||
return;
|
||||
}
|
||||
setLocalKeys(await readLocalKeys(workspaceId));
|
||||
} else {
|
||||
await gql({
|
||||
query: upsertByokMutation,
|
||||
variables: {
|
||||
input: {
|
||||
workspaceId,
|
||||
id:
|
||||
editingKey?.storage === ByokKeyStorage.server
|
||||
? editingKey.id
|
||||
: null,
|
||||
provider,
|
||||
name,
|
||||
description,
|
||||
storage,
|
||||
apiKey: apiKey || null,
|
||||
endpoint: endpoint || null,
|
||||
enabled: true,
|
||||
},
|
||||
},
|
||||
});
|
||||
await onSaved();
|
||||
}
|
||||
onOpenChange(false);
|
||||
setApiKey('');
|
||||
setTestResult(null);
|
||||
}, [
|
||||
apiKey,
|
||||
description,
|
||||
editingKey,
|
||||
endpoint,
|
||||
gql,
|
||||
localKeys,
|
||||
name,
|
||||
onOpenChange,
|
||||
onSaved,
|
||||
provider,
|
||||
setLocalKeys,
|
||||
storage,
|
||||
t,
|
||||
testResult?.ok,
|
||||
workspaceId,
|
||||
]);
|
||||
|
||||
return (
|
||||
<Modal
|
||||
width={520}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
title={
|
||||
editingKey ? byokT(t, 'modal.edit-title') : byokT(t, 'modal.add-title')
|
||||
}
|
||||
description={byokT(t, 'modal.description')}
|
||||
>
|
||||
<div className={styles.form}>
|
||||
<label className={styles.field}>
|
||||
<span className={styles.label}>{byokT(t, 'field.provider')}</span>
|
||||
<select
|
||||
className={styles.input}
|
||||
value={provider}
|
||||
onChange={event => {
|
||||
setProvider(event.target.value as ByokProvider);
|
||||
setTestResult(null);
|
||||
}}
|
||||
>
|
||||
{settings.allowedProviders.map(provider => (
|
||||
<option key={provider} value={provider}>
|
||||
{providerLabels[provider]}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
<label className={styles.field}>
|
||||
<span className={styles.label}>{byokT(t, 'field.key-name')}</span>
|
||||
<input
|
||||
className={styles.input}
|
||||
value={name}
|
||||
onChange={event => setName(event.target.value)}
|
||||
placeholder={byokT(t, 'placeholder.key-name')}
|
||||
/>
|
||||
</label>
|
||||
<label className={styles.field}>
|
||||
<span className={styles.label}>{byokT(t, 'field.description')}</span>
|
||||
<input
|
||||
className={styles.input}
|
||||
value={description}
|
||||
onChange={event => setDescription(event.target.value)}
|
||||
placeholder={byokT(t, 'placeholder.description')}
|
||||
/>
|
||||
</label>
|
||||
<label className={styles.field}>
|
||||
<span className={styles.label}>{byokT(t, 'field.storage')}</span>
|
||||
<select
|
||||
className={styles.input}
|
||||
value={storage}
|
||||
disabled={!!editingKey}
|
||||
onChange={event => {
|
||||
setStorage(event.target.value as ByokStorage);
|
||||
setTestResult(null);
|
||||
}}
|
||||
>
|
||||
<option value={ByokKeyStorage.server} disabled={!canAddServerKey}>
|
||||
{storageLabel(t, ByokKeyStorage.server)}
|
||||
</option>
|
||||
<option
|
||||
value={ByokKeyStorage.local}
|
||||
disabled={!localStorageSupported || !canAddLocalKey}
|
||||
>
|
||||
{canAddLocalKey
|
||||
? byokT(t, 'storage.local-this-device')
|
||||
: byokT(t, 'storage.local-desktop-only')}
|
||||
</option>
|
||||
</select>
|
||||
</label>
|
||||
<label className={styles.field}>
|
||||
<span className={styles.label}>{byokT(t, 'field.api-key')}</span>
|
||||
<input
|
||||
className={styles.input}
|
||||
value={apiKey}
|
||||
onChange={event => {
|
||||
setApiKey(event.target.value);
|
||||
setTestResult(null);
|
||||
}}
|
||||
type="password"
|
||||
/>
|
||||
</label>
|
||||
{settings.customEndpointSupported ? (
|
||||
<label className={styles.field}>
|
||||
<span className={styles.label}>{byokT(t, 'field.endpoint')}</span>
|
||||
<input
|
||||
className={styles.input}
|
||||
value={endpoint}
|
||||
onChange={event => {
|
||||
setEndpoint(event.target.value);
|
||||
setTestResult(null);
|
||||
}}
|
||||
placeholder="https://api.example.com/v1"
|
||||
/>
|
||||
</label>
|
||||
) : null}
|
||||
<div className={styles.modalActions}>
|
||||
<span
|
||||
className={`${styles.testStatus} ${
|
||||
testResult?.ok
|
||||
? styles.success
|
||||
: testResult && !testResult.ok
|
||||
? styles.error
|
||||
: ''
|
||||
}`}
|
||||
>
|
||||
{testResult?.ok
|
||||
? byokT(t, 'status.key-verified')
|
||||
: testResult
|
||||
? byokT(t, 'status.key-test-failed')
|
||||
: ''}
|
||||
</span>
|
||||
<Button
|
||||
variant="secondary"
|
||||
disabled={!canTest || testing}
|
||||
onClick={() => {
|
||||
testKey().catch(error => {
|
||||
logByokError('Failed to test BYOK key', error);
|
||||
notify.error({
|
||||
title: byokT(t, 'notify.test-failed.title'),
|
||||
message: byokT(t, 'notify.operation-failed.message'),
|
||||
});
|
||||
});
|
||||
}}
|
||||
>
|
||||
{byokT(t, 'action.test-key')}
|
||||
</Button>
|
||||
<Button variant="secondary" onClick={() => onOpenChange(false)}>
|
||||
{byokT(t, 'action.cancel')}
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
disabled={!testResult?.ok || !name}
|
||||
onClick={() => {
|
||||
save().catch(error => {
|
||||
logByokError('Failed to save BYOK key', error);
|
||||
notify.error({
|
||||
title: byokT(t, 'notify.save-failed.title'),
|
||||
message: byokT(t, 'notify.operation-failed.message'),
|
||||
});
|
||||
});
|
||||
}}
|
||||
>
|
||||
{byokT(t, 'action.save-key')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,98 @@
|
||||
import { useI18n } from '@affine/i18n';
|
||||
import {
|
||||
ChatWithAiIcon,
|
||||
ImageIcon,
|
||||
PenIcon,
|
||||
TocIcon,
|
||||
TranscriptWithAiIcon,
|
||||
} from '@blocksuite/icons/rc';
|
||||
import type { ReactNode } from 'react';
|
||||
|
||||
import * as styles from './index.css';
|
||||
import { byokT, capabilityRows, warningDescription } from './metadata';
|
||||
import type { ByokKey, ByokSettings } from './types';
|
||||
|
||||
function coverageIcon(
|
||||
icon: (typeof capabilityRows)[number]['icon']
|
||||
): ReactNode {
|
||||
switch (icon) {
|
||||
case 'chat':
|
||||
return <ChatWithAiIcon className={styles.capabilityIconSvg} />;
|
||||
case 'action':
|
||||
return <PenIcon className={styles.capabilityIconSvg} />;
|
||||
case 'image':
|
||||
return <ImageIcon className={styles.capabilityIconSvg} />;
|
||||
case 'transcript':
|
||||
return <TranscriptWithAiIcon className={styles.capabilityIconSvg} />;
|
||||
case 'indexing':
|
||||
return <TocIcon className={styles.capabilityIconSvg} />;
|
||||
}
|
||||
}
|
||||
|
||||
function isRowCovered(row: (typeof capabilityRows)[number], keys: ByokKey[]) {
|
||||
if (!row.coverageCapabilities.length) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return keys.some(key => {
|
||||
if (!key.configured || !key.enabled) {
|
||||
return false;
|
||||
}
|
||||
return (
|
||||
(!('storage' in row) || row.storage === key.storage) &&
|
||||
row.coverageCapabilities.every(capability =>
|
||||
key.capabilities.includes(capability)
|
||||
)
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
export const CoveragePanel = ({
|
||||
keys,
|
||||
settings,
|
||||
}: {
|
||||
keys: ByokKey[];
|
||||
settings: ByokSettings;
|
||||
}) => {
|
||||
const t = useI18n();
|
||||
|
||||
return (
|
||||
<div className={styles.panel}>
|
||||
<div className={styles.panelHeader}>
|
||||
<div className={styles.title}>{byokT(t, 'coverage.title')}</div>
|
||||
</div>
|
||||
<div className={styles.rows}>
|
||||
{capabilityRows.map(row => {
|
||||
const warning = settings.warnings.find(
|
||||
w => w.featureKind === row.featureKind
|
||||
);
|
||||
const covered = isRowCovered(row, keys);
|
||||
return (
|
||||
<div
|
||||
className={`${styles.row} ${styles.capabilityRow} ${
|
||||
covered ? '' : styles.capabilityRowInactive
|
||||
}`}
|
||||
data-covered={covered}
|
||||
data-testid={`workspace-byok-coverage-${row.featureKind}`}
|
||||
key={row.featureKind}
|
||||
>
|
||||
<div
|
||||
className={`${styles.capabilityIcon} ${
|
||||
covered ? styles.capabilityIconActive : ''
|
||||
}`}
|
||||
>
|
||||
{coverageIcon(row.icon)}
|
||||
</div>
|
||||
<div className={styles.rowMain}>
|
||||
<div className={styles.rowTitle}>{byokT(t, row.titleKey)}</div>
|
||||
<div className={styles.rowDescription}>
|
||||
{warningDescription(t, warning) ?? byokT(t, row.fallbackKey)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,19 @@
|
||||
function errorMetadata(error: unknown) {
|
||||
if (!error || typeof error !== 'object') {
|
||||
return { kind: typeof error };
|
||||
}
|
||||
const record = error as Record<string, unknown>;
|
||||
return {
|
||||
name: typeof record.name === 'string' ? record.name : undefined,
|
||||
code: typeof record.code === 'string' ? record.code : undefined,
|
||||
status:
|
||||
typeof record.status === 'number' || typeof record.status === 'string'
|
||||
? record.status
|
||||
: undefined,
|
||||
type: typeof record.type === 'string' ? record.type : undefined,
|
||||
};
|
||||
}
|
||||
|
||||
export function logByokError(context: string, error: unknown) {
|
||||
console.warn(context, errorMetadata(error));
|
||||
}
|
||||
@@ -0,0 +1,251 @@
|
||||
import { cssVar } from '@toeverything/theme';
|
||||
import { cssVarV2 } from '@toeverything/theme/v2';
|
||||
import { style } from '@vanilla-extract/css';
|
||||
|
||||
export const stack = style({
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
gap: 24,
|
||||
});
|
||||
|
||||
export const panel = style({
|
||||
border: `1px solid ${cssVarV2('layer/insideBorder/border')}`,
|
||||
borderRadius: 8,
|
||||
overflow: 'hidden',
|
||||
background: cssVarV2('layer/background/primary'),
|
||||
});
|
||||
|
||||
export const panelHeader = style({
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'space-between',
|
||||
gap: 12,
|
||||
padding: '12px 16px',
|
||||
borderBottom: `1px solid ${cssVarV2('layer/insideBorder/border')}`,
|
||||
});
|
||||
|
||||
export const title = style({
|
||||
fontSize: cssVar('fontSm'),
|
||||
fontWeight: 600,
|
||||
color: cssVarV2('text/primary'),
|
||||
});
|
||||
|
||||
export const description = style({
|
||||
fontSize: cssVar('fontXs'),
|
||||
lineHeight: '20px',
|
||||
color: cssVarV2('text/secondary'),
|
||||
});
|
||||
|
||||
export const empty = style({
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
alignItems: 'center',
|
||||
gap: 4,
|
||||
padding: '28px 20px',
|
||||
textAlign: 'center',
|
||||
});
|
||||
|
||||
export const rows = style({
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
});
|
||||
|
||||
export const row = style({
|
||||
display: 'grid',
|
||||
gridTemplateColumns: '24px 1fr auto',
|
||||
alignItems: 'center',
|
||||
gap: 12,
|
||||
padding: '12px 16px',
|
||||
borderBottom: `1px solid ${cssVarV2('layer/insideBorder/border')}`,
|
||||
selectors: {
|
||||
'&:last-child': {
|
||||
borderBottom: 0,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
export const capabilityRow = style({
|
||||
gridTemplateColumns: '32px 1fr',
|
||||
});
|
||||
|
||||
export const capabilityRowInactive = style({
|
||||
opacity: 0.48,
|
||||
background: cssVarV2('layer/background/secondary'),
|
||||
});
|
||||
|
||||
export const capabilityIcon = style({
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
width: 24,
|
||||
height: 24,
|
||||
borderRadius: 6,
|
||||
color: cssVarV2('text/secondary'),
|
||||
background: cssVarV2('layer/background/secondary'),
|
||||
});
|
||||
|
||||
export const capabilityIconActive = style({
|
||||
color: cssVarV2('button/primary'),
|
||||
background: '#f0f7ff',
|
||||
});
|
||||
|
||||
export const capabilityIconSvg = style({
|
||||
width: 16,
|
||||
height: 16,
|
||||
});
|
||||
|
||||
export const rowDisabled = style({
|
||||
opacity: 0.55,
|
||||
background: cssVarV2('layer/background/secondary'),
|
||||
});
|
||||
|
||||
export const dragHandle = style({
|
||||
color: cssVarV2('text/secondary'),
|
||||
cursor: 'grab',
|
||||
textAlign: 'center',
|
||||
});
|
||||
|
||||
export const rowMain = style({
|
||||
minWidth: 0,
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
gap: 4,
|
||||
});
|
||||
|
||||
export const rowTitle = style({
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
gap: 8,
|
||||
minWidth: 0,
|
||||
fontSize: cssVar('fontSm'),
|
||||
fontWeight: 600,
|
||||
color: cssVarV2('text/primary'),
|
||||
});
|
||||
|
||||
export const rowDescription = style({
|
||||
overflow: 'hidden',
|
||||
textOverflow: 'ellipsis',
|
||||
whiteSpace: 'nowrap',
|
||||
fontSize: cssVar('fontXs'),
|
||||
color: cssVarV2('text/secondary'),
|
||||
});
|
||||
|
||||
export const tags = style({
|
||||
display: 'flex',
|
||||
flexWrap: 'wrap',
|
||||
gap: 6,
|
||||
});
|
||||
|
||||
export const tag = style({
|
||||
borderRadius: 999,
|
||||
padding: '2px 8px',
|
||||
fontSize: 11,
|
||||
lineHeight: '16px',
|
||||
color: cssVarV2('text/secondary'),
|
||||
background: cssVarV2('layer/background/secondary'),
|
||||
});
|
||||
|
||||
export const dangerTag = style({
|
||||
color: '#b42318',
|
||||
background: '#fff5f5',
|
||||
});
|
||||
|
||||
export const rowActions = style({
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
gap: 8,
|
||||
});
|
||||
|
||||
export const notice = style({
|
||||
borderRadius: 8,
|
||||
padding: 12,
|
||||
border: `1px solid ${cssVarV2('layer/insideBorder/border')}`,
|
||||
background: cssVarV2('layer/background/secondary'),
|
||||
});
|
||||
|
||||
export const locked = style({
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
gap: 16,
|
||||
padding: 24,
|
||||
borderRadius: 8,
|
||||
background: cssVarV2('layer/background/secondary'),
|
||||
border: `1px solid ${cssVarV2('layer/insideBorder/border')}`,
|
||||
});
|
||||
|
||||
export const form = style({
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
gap: 12,
|
||||
});
|
||||
|
||||
export const field = style({
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
gap: 4,
|
||||
height: '3em',
|
||||
});
|
||||
|
||||
export const label = style({
|
||||
fontSize: cssVar('fontXs'),
|
||||
color: cssVarV2('text/secondary'),
|
||||
});
|
||||
|
||||
export const input = style({
|
||||
height: 32,
|
||||
width: '100%',
|
||||
boxSizing: 'border-box',
|
||||
borderRadius: 8,
|
||||
border: `1px solid ${cssVarV2('layer/insideBorder/border')}`,
|
||||
padding: '0 10px',
|
||||
fontSize: cssVar('fontSm'),
|
||||
lineHeight: '22px',
|
||||
background: cssVarV2('layer/background/primary'),
|
||||
color: cssVarV2('text/primary'),
|
||||
outline: 'none',
|
||||
selectors: {
|
||||
'&::placeholder': {
|
||||
color: cssVarV2('text/placeholder'),
|
||||
},
|
||||
'&:focus': {
|
||||
borderColor: cssVarV2('button/primary'),
|
||||
boxShadow: '0px 0px 0px 2px rgba(30, 150, 235, 0.30)',
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
export const modalActions = style({
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'flex-end',
|
||||
gap: 8,
|
||||
marginTop: 8,
|
||||
});
|
||||
|
||||
export const testStatus = style({
|
||||
marginRight: 'auto',
|
||||
fontSize: cssVar('fontXs'),
|
||||
});
|
||||
|
||||
export const error = style({
|
||||
color: '#b42318',
|
||||
});
|
||||
|
||||
export const success = style({
|
||||
color: '#168a58',
|
||||
});
|
||||
|
||||
export const chart = style({
|
||||
display: 'grid',
|
||||
gridTemplateColumns: 'repeat(30, minmax(4px, 1fr))',
|
||||
alignItems: 'end',
|
||||
gap: 4,
|
||||
height: 140,
|
||||
padding: '16px 16px 24px',
|
||||
});
|
||||
|
||||
export const bar = style({
|
||||
minHeight: 2,
|
||||
borderRadius: '4px 4px 0 0',
|
||||
background: '#5b8def',
|
||||
});
|
||||
@@ -0,0 +1,700 @@
|
||||
/**
|
||||
* @vitest-environment happy-dom
|
||||
*/
|
||||
|
||||
import {
|
||||
cleanup,
|
||||
fireEvent,
|
||||
render,
|
||||
screen,
|
||||
waitFor,
|
||||
} from '@testing-library/react';
|
||||
import type * as Infra from '@toeverything/infra';
|
||||
import type { ButtonHTMLAttributes, ReactNode } from 'react';
|
||||
import { afterEach, beforeEach, describe, expect, test, vi } from 'vitest';
|
||||
|
||||
const gqlMock = vi.hoisted(() => vi.fn());
|
||||
const workspaceState = vi.hoisted(() => ({
|
||||
id: 'workspace-1',
|
||||
}));
|
||||
const electronApiState = vi.hoisted(() => ({
|
||||
apis: undefined as
|
||||
| {
|
||||
byokStorage?: {
|
||||
isSupported: () => Promise<boolean>;
|
||||
listWorkspaceKeys: (workspaceId: string) => Promise<unknown[]>;
|
||||
};
|
||||
}
|
||||
| undefined,
|
||||
}));
|
||||
const WorkspaceServerServiceToken = vi.hoisted(
|
||||
() => class WorkspaceServerService {}
|
||||
);
|
||||
const WorkspaceServiceToken = vi.hoisted(() => class WorkspaceService {});
|
||||
|
||||
const ByokProvider = vi.hoisted(() => ({
|
||||
openai: 'openai',
|
||||
anthropic: 'anthropic',
|
||||
gemini: 'gemini',
|
||||
fal: 'fal',
|
||||
}));
|
||||
const ByokKeyStorage = vi.hoisted(() => ({
|
||||
server: 'server',
|
||||
local: 'local',
|
||||
}));
|
||||
const ByokKeyTestStatus = vi.hoisted(() => ({
|
||||
untested: 'untested',
|
||||
passed: 'passed',
|
||||
failed: 'failed',
|
||||
}));
|
||||
|
||||
const workspaceByokSettingsQuery = vi.hoisted(() =>
|
||||
Symbol('workspaceByokSettingsQuery')
|
||||
);
|
||||
const testWorkspaceByokConfigMutation = vi.hoisted(() =>
|
||||
Symbol('testWorkspaceByokConfigMutation')
|
||||
);
|
||||
const upsertWorkspaceByokConfigMutation = vi.hoisted(() =>
|
||||
Symbol('upsertWorkspaceByokConfigMutation')
|
||||
);
|
||||
const clearWorkspaceByokConfigsMutation = vi.hoisted(() =>
|
||||
Symbol('clearWorkspaceByokConfigsMutation')
|
||||
);
|
||||
const deleteWorkspaceByokConfigMutation = vi.hoisted(() =>
|
||||
Symbol('deleteWorkspaceByokConfigMutation')
|
||||
);
|
||||
|
||||
vi.mock('@affine/component', () => ({
|
||||
Button: ({
|
||||
children,
|
||||
...props
|
||||
}: ButtonHTMLAttributes<HTMLButtonElement> & { children: ReactNode }) => (
|
||||
<button {...props}>{children}</button>
|
||||
),
|
||||
DragHandle: () => <span>drag-handle</span>,
|
||||
IconButton: ({ title, onClick }: { title: string; onClick?: () => void }) => (
|
||||
<button onClick={onClick}>{title}</button>
|
||||
),
|
||||
Modal: ({
|
||||
open,
|
||||
title,
|
||||
children,
|
||||
}: {
|
||||
open: boolean;
|
||||
title: string;
|
||||
children: ReactNode;
|
||||
}) =>
|
||||
open ? (
|
||||
<div role="dialog" aria-label={title}>
|
||||
{children}
|
||||
</div>
|
||||
) : null,
|
||||
notify: {
|
||||
error: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock('@affine/component/setting-components', () => ({
|
||||
SettingHeader: ({
|
||||
title,
|
||||
subtitle,
|
||||
}: {
|
||||
title: string;
|
||||
subtitle?: string;
|
||||
}) => (
|
||||
<header>
|
||||
<h1>{title}</h1>
|
||||
{subtitle ? <p>{subtitle}</p> : null}
|
||||
</header>
|
||||
),
|
||||
SettingWrapper: ({ children }: { children: ReactNode }) => (
|
||||
<main>{children}</main>
|
||||
),
|
||||
}));
|
||||
|
||||
vi.mock('@affine/core/modules/cloud', () => ({
|
||||
WorkspaceServerService: WorkspaceServerServiceToken,
|
||||
}));
|
||||
|
||||
vi.mock('@affine/core/modules/workspace', () => ({
|
||||
WorkspaceService: WorkspaceServiceToken,
|
||||
}));
|
||||
|
||||
vi.mock('@affine/electron-api', () => ({
|
||||
get apis() {
|
||||
return electronApiState.apis;
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock('@affine/graphql', () => ({
|
||||
ByokKeyStorage,
|
||||
ByokKeyTestStatus,
|
||||
ByokProvider,
|
||||
clearWorkspaceByokConfigsMutation,
|
||||
deleteWorkspaceByokConfigMutation,
|
||||
testWorkspaceByokConfigMutation,
|
||||
upsertWorkspaceByokConfigMutation,
|
||||
workspaceByokSettingsQuery,
|
||||
}));
|
||||
|
||||
vi.mock('@affine/i18n', () => {
|
||||
const messages: Record<string, string> = {
|
||||
'com.affine.settings.workspace.byok.action.add-key': 'Add key',
|
||||
'com.affine.settings.workspace.byok.action.edit': 'Edit',
|
||||
'com.affine.settings.workspace.byok.action.delete': 'Delete',
|
||||
'com.affine.settings.workspace.byok.action.test-key': 'Test key',
|
||||
'com.affine.settings.workspace.byok.action.save-key': 'Save key',
|
||||
'com.affine.settings.workspace.byok.action.cancel': 'Cancel',
|
||||
'com.affine.settings.workspace.byok.action.clear-all':
|
||||
'Clear all BYOK keys',
|
||||
'com.affine.settings.workspace.byok.field.api-key': 'API key',
|
||||
'com.affine.settings.workspace.byok.field.storage': 'Key storage',
|
||||
'com.affine.settings.workspace.byok.placeholder.key-name': 'Primary',
|
||||
'com.affine.settings.workspace.byok.status.key-verified': 'Key verified',
|
||||
'com.affine.settings.workspace.byok.status.disabled-after-failure':
|
||||
'Disabled after failure',
|
||||
'com.affine.settings.workspace.byok.storage.local': 'Local',
|
||||
'com.affine.settings.workspace.byok.storage.server': 'Server',
|
||||
'com.affine.settings.workspace.byok.storage.local-this-device':
|
||||
'Local (this device)',
|
||||
'com.affine.settings.workspace.byok.storage.local-desktop-only':
|
||||
'Local (Desktop only)',
|
||||
'com.affine.settings.workspace.byok.usage.tokens': '{{count}} tokens',
|
||||
'com.affine.settings.workspace.byok.notify.operation-failed.message':
|
||||
'Please try again.',
|
||||
'com.affine.settings.workspace.byok.notify.test-failed.title':
|
||||
'Key test failed',
|
||||
'com.affine.settings.workspace.byok.notify.load-failed.title':
|
||||
'BYOK settings not loaded',
|
||||
'com.affine.settings.workspace.byok.notify.save-failed.title':
|
||||
'BYOK key not saved',
|
||||
'com.affine.settings.workspace.byok.notify.delete-failed.title':
|
||||
'BYOK key not deleted',
|
||||
'com.affine.settings.workspace.byok.notify.reorder-failed.title':
|
||||
'BYOK keys not reordered',
|
||||
'com.affine.settings.workspace.byok.notify.clear-failed.title':
|
||||
'BYOK keys not cleared',
|
||||
};
|
||||
const translate = (key: string, options?: Record<string, unknown>) => {
|
||||
let message = messages[key] ?? key;
|
||||
for (const [name, value] of Object.entries(options ?? {})) {
|
||||
message = message.replaceAll(`{{${name}}}`, String(value));
|
||||
}
|
||||
return message;
|
||||
};
|
||||
const t = new Proxy(
|
||||
{
|
||||
t: translate,
|
||||
},
|
||||
{
|
||||
get(target, key: string) {
|
||||
if (key in target) {
|
||||
return target[key as keyof typeof target];
|
||||
}
|
||||
return (options?: Record<string, unknown>) => translate(key, options);
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
return {
|
||||
useI18n: () => t,
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock('@blocksuite/icons/rc', () => ({
|
||||
ChatWithAiIcon: () => <span>chat-ai</span>,
|
||||
DeleteIcon: () => <span>delete</span>,
|
||||
EditIcon: () => <span>edit</span>,
|
||||
ImageIcon: () => <span>image</span>,
|
||||
PenIcon: () => <span>pen</span>,
|
||||
TocIcon: () => <span>toc</span>,
|
||||
TranscriptWithAiIcon: () => <span>transcript</span>,
|
||||
}));
|
||||
|
||||
vi.mock('@toeverything/infra', async importOriginal => {
|
||||
const actual = await importOriginal<typeof Infra>();
|
||||
|
||||
return {
|
||||
...actual,
|
||||
useService: (token: unknown) => {
|
||||
if (token === WorkspaceServerServiceToken) {
|
||||
return {
|
||||
server: {
|
||||
gql: gqlMock,
|
||||
},
|
||||
};
|
||||
}
|
||||
if (token === WorkspaceServiceToken) {
|
||||
return {
|
||||
workspace: workspaceState,
|
||||
};
|
||||
}
|
||||
return {};
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
import { WorkspaceByokSetting } from '.';
|
||||
import { logByokError } from './errors';
|
||||
import { UsagePanel } from './usage';
|
||||
|
||||
function settings(overrides: Record<string, unknown> = {}) {
|
||||
return {
|
||||
workspaceId: 'workspace-1',
|
||||
entitled: true,
|
||||
serverEntitled: true,
|
||||
localEntitled: false,
|
||||
entitlementRequired: ['Pro', 'Team', 'Believer'],
|
||||
allowedProviders: ['openai', 'anthropic', 'gemini', 'fal'],
|
||||
localStorageSupported: false,
|
||||
customEndpointSupported: false,
|
||||
hasAiPlan: true,
|
||||
keys: [],
|
||||
warnings: [],
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
function byokKey(overrides: Record<string, unknown> = {}) {
|
||||
return {
|
||||
id: 'server-key',
|
||||
provider: ByokProvider.openai,
|
||||
name: 'Primary',
|
||||
description: 'Workspace fallback key',
|
||||
storage: ByokKeyStorage.server,
|
||||
configured: true,
|
||||
enabled: true,
|
||||
endpoint: null,
|
||||
endpointEditable: false,
|
||||
sortOrder: 0,
|
||||
capabilities: ['Text', 'Image input', 'Actions', 'Image generate'],
|
||||
testStatus: ByokKeyTestStatus.passed,
|
||||
disabledReason: null,
|
||||
lastTestedAt: null,
|
||||
lastTestError: null,
|
||||
lastUsedAt: null,
|
||||
lastErrorAt: null,
|
||||
lastError: null,
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
function settingsResponse(overrides: Record<string, unknown> = {}) {
|
||||
return {
|
||||
workspace: {
|
||||
byokSettings: settings(overrides),
|
||||
byokUsage: [],
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
describe('WorkspaceByokSetting', () => {
|
||||
afterEach(() => {
|
||||
cleanup();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
gqlMock.mockReset();
|
||||
gqlMock.mockImplementation(async ({ query }) => {
|
||||
if (query === workspaceByokSettingsQuery) {
|
||||
return settingsResponse();
|
||||
}
|
||||
throw new Error('Unexpected GraphQL operation');
|
||||
});
|
||||
vi.stubGlobal('BUILD_CONFIG', { isElectron: false });
|
||||
electronApiState.apis = undefined;
|
||||
});
|
||||
|
||||
test('renders locked state without key management controls', async () => {
|
||||
gqlMock.mockImplementation(async ({ query }) => {
|
||||
if (query === workspaceByokSettingsQuery) {
|
||||
return settingsResponse({
|
||||
entitled: false,
|
||||
serverEntitled: false,
|
||||
localEntitled: false,
|
||||
});
|
||||
}
|
||||
throw new Error('Unexpected GraphQL operation');
|
||||
});
|
||||
|
||||
render(<WorkspaceByokSetting />);
|
||||
|
||||
await screen.findByTestId('workspace-byok-locked');
|
||||
expect(screen.queryByText('Add key')).toBeNull();
|
||||
expect(screen.queryByTestId('workspace-byok-empty')).toBeNull();
|
||||
});
|
||||
|
||||
test('renders empty state and keeps save disabled until key test passes', async () => {
|
||||
gqlMock.mockImplementation(async ({ query }) => {
|
||||
if (query === workspaceByokSettingsQuery) {
|
||||
return settingsResponse();
|
||||
}
|
||||
if (query === testWorkspaceByokConfigMutation) {
|
||||
return {
|
||||
testWorkspaceByokConfig: {
|
||||
ok: true,
|
||||
status: 'passed',
|
||||
message: null,
|
||||
},
|
||||
};
|
||||
}
|
||||
if (query === upsertWorkspaceByokConfigMutation) {
|
||||
return { upsertWorkspaceByokConfig: { id: 'server-key' } };
|
||||
}
|
||||
throw new Error('Unexpected GraphQL operation');
|
||||
});
|
||||
|
||||
render(<WorkspaceByokSetting />);
|
||||
|
||||
await screen.findByTestId('workspace-byok-empty');
|
||||
fireEvent.click(screen.getAllByText('Add key')[0]);
|
||||
expect(screen.getByText<HTMLButtonElement>('Save key').disabled).toBe(true);
|
||||
|
||||
fireEvent.change(screen.getByPlaceholderText('Primary'), {
|
||||
target: { value: 'Primary' },
|
||||
});
|
||||
fireEvent.change(screen.getByLabelText('API key'), {
|
||||
target: { value: 'sk-test' },
|
||||
});
|
||||
fireEvent.click(screen.getByText('Test key'));
|
||||
|
||||
await screen.findByText('Key verified');
|
||||
expect(screen.getByText<HTMLButtonElement>('Save key').disabled).toBe(
|
||||
false
|
||||
);
|
||||
fireEvent.click(screen.getByText('Save key'));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(gqlMock).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
query: upsertWorkspaceByokConfigMutation,
|
||||
})
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
test('keeps local storage disabled on web even for local-entitled users', async () => {
|
||||
gqlMock.mockImplementation(async ({ query }) => {
|
||||
if (query === workspaceByokSettingsQuery) {
|
||||
return settingsResponse({
|
||||
localEntitled: true,
|
||||
localStorageSupported: true,
|
||||
});
|
||||
}
|
||||
throw new Error('Unexpected GraphQL operation');
|
||||
});
|
||||
|
||||
render(<WorkspaceByokSetting />);
|
||||
|
||||
await screen.findByTestId('workspace-byok-empty');
|
||||
fireEvent.click(screen.getAllByText('Add key')[0]);
|
||||
|
||||
const storageSelect =
|
||||
screen.getByLabelText<HTMLSelectElement>('Key storage');
|
||||
const localOption = Array.from(storageSelect.options).find(
|
||||
option => option.value === ByokKeyStorage.local
|
||||
);
|
||||
expect(localOption?.disabled).toBe(true);
|
||||
});
|
||||
|
||||
test('reorders server keys within their storage bucket', async () => {
|
||||
gqlMock.mockImplementation(async ({ query }) => {
|
||||
if (query === workspaceByokSettingsQuery) {
|
||||
return settingsResponse({
|
||||
keys: [
|
||||
byokKey({ id: 'server-1', name: 'First', sortOrder: 0 }),
|
||||
byokKey({ id: 'server-2', name: 'Second', sortOrder: 1 }),
|
||||
],
|
||||
});
|
||||
}
|
||||
return {};
|
||||
});
|
||||
|
||||
render(<WorkspaceByokSetting />);
|
||||
|
||||
const firstRow = (await screen.findByText('OpenAI / First')).closest(
|
||||
'[draggable="true"]'
|
||||
);
|
||||
const secondRow = screen
|
||||
.getByText('OpenAI / Second')
|
||||
.closest('[draggable="true"]');
|
||||
|
||||
expect(firstRow).not.toBeNull();
|
||||
expect(secondRow).not.toBeNull();
|
||||
fireEvent.dragStart(firstRow as Element);
|
||||
fireEvent.dragOver(secondRow as Element);
|
||||
fireEvent.drop(secondRow as Element);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(gqlMock).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
variables: expect.objectContaining({
|
||||
input: expect.objectContaining({
|
||||
workspaceId: 'workspace-1',
|
||||
storage: ByokKeyStorage.server,
|
||||
ids: ['server-2', 'server-1'],
|
||||
}),
|
||||
}),
|
||||
})
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
test('marks coverage rows by configured provider support', async () => {
|
||||
let keys = [
|
||||
byokKey({ provider: ByokProvider.openai }),
|
||||
byokKey({
|
||||
id: 'disabled-gemini',
|
||||
provider: ByokProvider.gemini,
|
||||
enabled: false,
|
||||
capabilities: [
|
||||
'Text',
|
||||
'Image input',
|
||||
'Actions',
|
||||
'Image generate',
|
||||
'Transcript',
|
||||
'Indexing',
|
||||
],
|
||||
}),
|
||||
byokKey({
|
||||
id: 'local-gemini',
|
||||
provider: ByokProvider.gemini,
|
||||
storage: ByokKeyStorage.local,
|
||||
capabilities: ['Text', 'Image input', 'Actions', 'Image generate'],
|
||||
}),
|
||||
];
|
||||
|
||||
gqlMock.mockImplementation(async ({ query }) => {
|
||||
if (query === workspaceByokSettingsQuery) {
|
||||
return settingsResponse({
|
||||
keys,
|
||||
});
|
||||
}
|
||||
throw new Error('Unexpected GraphQL operation');
|
||||
});
|
||||
|
||||
render(<WorkspaceByokSetting />);
|
||||
|
||||
expect(
|
||||
(await screen.findByTestId('workspace-byok-coverage-chat')).dataset
|
||||
.covered
|
||||
).toBe('true');
|
||||
expect(
|
||||
screen.getByTestId('workspace-byok-coverage-action').dataset.covered
|
||||
).toBe('true');
|
||||
expect(
|
||||
screen.getByTestId('workspace-byok-coverage-image').dataset.covered
|
||||
).toBe('true');
|
||||
expect(
|
||||
screen.getByTestId('workspace-byok-coverage-transcript').dataset.covered
|
||||
).toBe('false');
|
||||
expect(
|
||||
screen.getByTestId('workspace-byok-coverage-workspace_indexing').dataset
|
||||
.covered
|
||||
).toBe('false');
|
||||
expect(screen.getAllByTestId(/^workspace-byok-coverage-/)).toHaveLength(5);
|
||||
|
||||
cleanup();
|
||||
keys = [
|
||||
byokKey({
|
||||
provider: ByokProvider.gemini,
|
||||
capabilities: [
|
||||
'Text',
|
||||
'Image input',
|
||||
'Actions',
|
||||
'Image generate',
|
||||
'Transcript',
|
||||
'Indexing',
|
||||
],
|
||||
}),
|
||||
];
|
||||
render(<WorkspaceByokSetting />);
|
||||
|
||||
expect(
|
||||
(await screen.findByTestId('workspace-byok-coverage-transcript')).dataset
|
||||
.covered
|
||||
).toBe('true');
|
||||
expect(
|
||||
screen.getByTestId('workspace-byok-coverage-workspace_indexing').dataset
|
||||
.covered
|
||||
).toBe('true');
|
||||
});
|
||||
|
||||
test('restores a failed server row after key test passes', async () => {
|
||||
gqlMock.mockImplementation(async ({ query }) => {
|
||||
if (query === workspaceByokSettingsQuery) {
|
||||
return settingsResponse({
|
||||
keys: [
|
||||
byokKey({
|
||||
enabled: false,
|
||||
testStatus: ByokKeyTestStatus.failed,
|
||||
disabledReason: 'recent_failure',
|
||||
lastErrorAt: '2026-05-01T00:00:00.000Z',
|
||||
lastError: 'Provider rejected the API key.',
|
||||
}),
|
||||
],
|
||||
});
|
||||
}
|
||||
if (query === testWorkspaceByokConfigMutation) {
|
||||
return {
|
||||
testWorkspaceByokConfig: {
|
||||
ok: true,
|
||||
status: 'passed',
|
||||
message: null,
|
||||
},
|
||||
};
|
||||
}
|
||||
if (query === upsertWorkspaceByokConfigMutation) {
|
||||
return {
|
||||
upsertWorkspaceByokConfig: {
|
||||
id: 'server-key',
|
||||
},
|
||||
};
|
||||
}
|
||||
throw new Error('Unexpected GraphQL operation');
|
||||
});
|
||||
|
||||
render(<WorkspaceByokSetting />);
|
||||
|
||||
await screen.findByText('Disabled after failure');
|
||||
fireEvent.click(screen.getByText('Edit'));
|
||||
fireEvent.change(screen.getByLabelText('API key'), {
|
||||
target: { value: 'sk-test' },
|
||||
});
|
||||
fireEvent.click(screen.getByText('Test key'));
|
||||
|
||||
await screen.findByText('Key verified');
|
||||
fireEvent.click(screen.getByText('Save key'));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(gqlMock).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
query: upsertWorkspaceByokConfigMutation,
|
||||
variables: expect.objectContaining({
|
||||
input: expect.objectContaining({
|
||||
id: 'server-key',
|
||||
enabled: true,
|
||||
}),
|
||||
}),
|
||||
})
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
test('tests a saved server key without resending plaintext', async () => {
|
||||
gqlMock.mockImplementation(async ({ query }) => {
|
||||
if (query === workspaceByokSettingsQuery) {
|
||||
return settingsResponse({
|
||||
keys: [byokKey()],
|
||||
});
|
||||
}
|
||||
if (query === testWorkspaceByokConfigMutation) {
|
||||
return {
|
||||
testWorkspaceByokConfig: {
|
||||
ok: true,
|
||||
status: 'passed',
|
||||
message: null,
|
||||
},
|
||||
};
|
||||
}
|
||||
if (query === upsertWorkspaceByokConfigMutation) {
|
||||
return {
|
||||
upsertWorkspaceByokConfig: {
|
||||
id: 'server-key',
|
||||
},
|
||||
};
|
||||
}
|
||||
throw new Error('Unexpected GraphQL operation');
|
||||
});
|
||||
|
||||
render(<WorkspaceByokSetting />);
|
||||
|
||||
await screen.findByText('OpenAI / Primary');
|
||||
fireEvent.click(screen.getByText('Edit'));
|
||||
expect(screen.getByText<HTMLButtonElement>('Test key').disabled).toBe(
|
||||
false
|
||||
);
|
||||
fireEvent.click(screen.getByText('Test key'));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(gqlMock).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
query: testWorkspaceByokConfigMutation,
|
||||
variables: expect.objectContaining({
|
||||
input: expect.objectContaining({
|
||||
apiKey: null,
|
||||
configId: 'server-key',
|
||||
}),
|
||||
}),
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
await screen.findByText('Key verified');
|
||||
fireEvent.click(screen.getByText('Save key'));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(gqlMock).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
query: upsertWorkspaceByokConfigMutation,
|
||||
variables: expect.objectContaining({
|
||||
input: expect.objectContaining({
|
||||
apiKey: null,
|
||||
id: 'server-key',
|
||||
}),
|
||||
}),
|
||||
})
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('UsagePanel', () => {
|
||||
afterEach(() => {
|
||||
cleanup();
|
||||
});
|
||||
|
||||
test('aggregates usage rows by date before rendering bars', () => {
|
||||
const today = new Date().toISOString();
|
||||
render(
|
||||
<UsagePanel
|
||||
keys={[]}
|
||||
usage={[
|
||||
{ date: today, featureKind: 'chat', totalTokens: 3 },
|
||||
{ date: today, featureKind: 'transcript', totalTokens: 5 },
|
||||
]}
|
||||
onClearAll={() => {}}
|
||||
/>
|
||||
);
|
||||
|
||||
expect(screen.getByTitle('8 tokens')).not.toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('logByokError', () => {
|
||||
test('logs safe metadata without raw error message', () => {
|
||||
const warn = vi.spyOn(console, 'warn').mockImplementation(() => {});
|
||||
const error = Object.assign(
|
||||
new Error('authorization: Bearer token=a+b%2F=='),
|
||||
{
|
||||
code: 'BAD_REQUEST',
|
||||
status: 400,
|
||||
type: 'bad_request',
|
||||
}
|
||||
);
|
||||
|
||||
try {
|
||||
logByokError('byok', error);
|
||||
expect(warn).toHaveBeenCalledWith('byok', {
|
||||
name: 'Error',
|
||||
code: 'BAD_REQUEST',
|
||||
status: 400,
|
||||
type: 'bad_request',
|
||||
});
|
||||
expect(JSON.stringify(warn.mock.calls)).not.toContain('token=a+b%2F==');
|
||||
} finally {
|
||||
warn.mockRestore();
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,363 @@
|
||||
import { Button, notify } from '@affine/component';
|
||||
import {
|
||||
SettingHeader,
|
||||
SettingWrapper,
|
||||
} from '@affine/component/setting-components';
|
||||
import { WorkspaceServerService } from '@affine/core/modules/cloud';
|
||||
import { WorkspaceService } from '@affine/core/modules/workspace';
|
||||
import {
|
||||
ByokKeyStorage,
|
||||
clearWorkspaceByokConfigsMutation as clearByokMutation,
|
||||
deleteWorkspaceByokConfigMutation as deleteByokMutation,
|
||||
type GraphQLQuery,
|
||||
workspaceByokSettingsQuery as byokSettingsQuery,
|
||||
} from '@affine/graphql';
|
||||
import { useI18n } from '@affine/i18n';
|
||||
import { useService } from '@toeverything/infra';
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react';
|
||||
|
||||
import { AddKeyModal } from './add-key-modal';
|
||||
import { CoveragePanel } from './coverage';
|
||||
import { logByokError } from './errors';
|
||||
import * as styles from './index.css';
|
||||
import { KeyList } from './key-list';
|
||||
import {
|
||||
clearLocalKeys,
|
||||
deleteLocalKey,
|
||||
localByokStorageSupported,
|
||||
readLocalKeys,
|
||||
reorderLocalKeys,
|
||||
} from './local-storage';
|
||||
import { byokT } from './metadata';
|
||||
import type {
|
||||
ByokKey,
|
||||
ByokSettings,
|
||||
ByokStorage,
|
||||
ByokUsagePoint,
|
||||
GqlFn,
|
||||
} from './types';
|
||||
import { UsagePanel } from './usage';
|
||||
|
||||
const reorderByokMutation = {
|
||||
id: 'reorderWorkspaceByokConfigsMutation',
|
||||
op: 'reorderWorkspaceByokConfigs',
|
||||
query: `mutation reorderWorkspaceByokConfigs($input: ReorderWorkspaceByokConfigsInput!) {
|
||||
reorderWorkspaceByokConfigs(input: $input) {
|
||||
id
|
||||
sortOrder
|
||||
}
|
||||
}`,
|
||||
} satisfies GraphQLQuery;
|
||||
|
||||
export const WorkspaceByokSetting = () => {
|
||||
const t = useI18n();
|
||||
const workspace = useService(WorkspaceService).workspace;
|
||||
const workspaceServer = useService(WorkspaceServerService);
|
||||
const [settings, setSettings] = useState<ByokSettings | null>(null);
|
||||
const [usage, setUsage] = useState<ByokUsagePoint[]>([]);
|
||||
const [localKeys, setLocalKeys] = useState<ByokKey[]>([]);
|
||||
const [modalOpen, setModalOpen] = useState(false);
|
||||
const [editingKey, setEditingKey] = useState<ByokKey | null>(null);
|
||||
const [draggingKey, setDraggingKey] = useState<{
|
||||
id: string;
|
||||
storage: ByokStorage;
|
||||
} | null>(null);
|
||||
|
||||
const load = useCallback(async () => {
|
||||
if (!workspaceServer.server) {
|
||||
return;
|
||||
}
|
||||
const to = new Date();
|
||||
const from = new Date(to.getTime() - 30 * 24 * 60 * 60 * 1000);
|
||||
const gql = workspaceServer.server.gql as GqlFn;
|
||||
const data = await gql({
|
||||
query: byokSettingsQuery,
|
||||
variables: {
|
||||
id: workspace.id,
|
||||
from: from.toISOString(),
|
||||
to: to.toISOString(),
|
||||
},
|
||||
});
|
||||
const [localStorageSupported, nextLocalKeys] = await Promise.all([
|
||||
localByokStorageSupported(),
|
||||
readLocalKeys(workspace.id),
|
||||
]);
|
||||
setSettings({
|
||||
...data.workspace.byokSettings,
|
||||
localStorageSupported:
|
||||
data.workspace.byokSettings.localEntitled && localStorageSupported,
|
||||
});
|
||||
setUsage(data.workspace.byokUsage);
|
||||
setLocalKeys(nextLocalKeys);
|
||||
}, [workspace.id, workspaceServer.server]);
|
||||
|
||||
useEffect(() => {
|
||||
load().catch(error => {
|
||||
logByokError('Failed to load BYOK settings', error);
|
||||
notify.error({
|
||||
title: byokT(t, 'notify.load-failed.title'),
|
||||
message: byokT(t, 'notify.operation-failed.message'),
|
||||
});
|
||||
});
|
||||
}, [load, t]);
|
||||
|
||||
const keys = useMemo(() => {
|
||||
return [...localKeys, ...(settings?.keys ?? [])].toSorted((a, b) => {
|
||||
if (a.storage !== b.storage) {
|
||||
return a.storage === ByokKeyStorage.local ? -1 : 1;
|
||||
}
|
||||
return a.sortOrder - b.sortOrder;
|
||||
});
|
||||
}, [localKeys, settings?.keys]);
|
||||
const canAddServerKey = settings?.serverEntitled ?? false;
|
||||
const canAddLocalKey =
|
||||
(settings?.localEntitled ?? false) &&
|
||||
(settings?.localStorageSupported ?? false);
|
||||
const canManageKeys = canAddServerKey || canAddLocalKey;
|
||||
|
||||
const clearAll = useCallback(async () => {
|
||||
if (!settings) {
|
||||
return;
|
||||
}
|
||||
if (!workspaceServer.server && settings.serverEntitled) {
|
||||
return;
|
||||
}
|
||||
if (settings.serverEntitled && workspaceServer.server) {
|
||||
const gql = workspaceServer.server.gql as GqlFn;
|
||||
await gql({
|
||||
query: clearByokMutation,
|
||||
variables: { workspaceId: workspace.id },
|
||||
});
|
||||
}
|
||||
if (settings.localStorageSupported) {
|
||||
await clearLocalKeys(workspace.id);
|
||||
}
|
||||
setLocalKeys([]);
|
||||
await load();
|
||||
}, [load, settings, workspace.id, workspaceServer.server]);
|
||||
|
||||
const deleteKey = useCallback(
|
||||
async (key: ByokKey) => {
|
||||
if (key.storage === ByokKeyStorage.local) {
|
||||
await deleteLocalKey(workspace.id, key.id);
|
||||
setLocalKeys(await readLocalKeys(workspace.id));
|
||||
return;
|
||||
}
|
||||
const gql = workspaceServer.server?.gql as
|
||||
| ((input: {
|
||||
query: GraphQLQuery;
|
||||
variables?: Record<string, unknown>;
|
||||
}) => Promise<unknown>)
|
||||
| undefined;
|
||||
await gql?.({
|
||||
query: deleteByokMutation,
|
||||
variables: { workspaceId: workspace.id, id: key.id },
|
||||
});
|
||||
await load();
|
||||
},
|
||||
[load, workspace.id, workspaceServer.server]
|
||||
);
|
||||
|
||||
const reorderKey = useCallback(
|
||||
async (targetKey: ByokKey) => {
|
||||
if (!draggingKey || draggingKey.id === targetKey.id) {
|
||||
return;
|
||||
}
|
||||
if (draggingKey.storage !== targetKey.storage) {
|
||||
notify.error({
|
||||
title: byokT(t, 'notify.cross-storage-reorder.title'),
|
||||
message: byokT(t, 'notify.cross-storage-reorder.message'),
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const bucket = keys.filter(key => key.storage === targetKey.storage);
|
||||
const fromIndex = bucket.findIndex(key => key.id === draggingKey.id);
|
||||
const toIndex = bucket.findIndex(key => key.id === targetKey.id);
|
||||
if (fromIndex === -1 || toIndex === -1) {
|
||||
return;
|
||||
}
|
||||
|
||||
const nextBucket = [...bucket];
|
||||
const [moved] = nextBucket.splice(fromIndex, 1);
|
||||
nextBucket.splice(toIndex, 0, moved);
|
||||
const nextBucketIds = nextBucket.map(key => key.id);
|
||||
|
||||
if (targetKey.storage === ByokKeyStorage.local) {
|
||||
setLocalKeys(await reorderLocalKeys(workspace.id, nextBucketIds));
|
||||
return;
|
||||
}
|
||||
|
||||
const gql = workspaceServer.server?.gql as
|
||||
| ((input: {
|
||||
query: GraphQLQuery;
|
||||
variables?: Record<string, unknown>;
|
||||
}) => Promise<unknown>)
|
||||
| undefined;
|
||||
await gql?.({
|
||||
query: reorderByokMutation,
|
||||
variables: {
|
||||
input: {
|
||||
workspaceId: workspace.id,
|
||||
storage: ByokKeyStorage.server,
|
||||
ids: nextBucketIds,
|
||||
},
|
||||
},
|
||||
});
|
||||
await load();
|
||||
},
|
||||
[draggingKey, keys, load, t, workspace.id, workspaceServer.server]
|
||||
);
|
||||
|
||||
if (!settings) {
|
||||
return (
|
||||
<SettingHeader
|
||||
title={byokT(t, 'title-beta')}
|
||||
subtitle={byokT(t, 'loading')}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (!settings.entitled) {
|
||||
return (
|
||||
<>
|
||||
<SettingHeader
|
||||
title={byokT(t, 'title-beta')}
|
||||
subtitle={byokT(t, 'subtitle')}
|
||||
/>
|
||||
<SettingWrapper>
|
||||
<div className={styles.locked} data-testid="workspace-byok-locked">
|
||||
<div>
|
||||
<div className={styles.title}>{byokT(t, 'locked.title')}</div>
|
||||
<div className={styles.description}>
|
||||
{byokT(t, 'locked.description')}
|
||||
</div>
|
||||
</div>
|
||||
<div className={styles.tags}>
|
||||
{settings.entitlementRequired.map(plan => (
|
||||
<span className={styles.tag} key={plan}>
|
||||
{plan}
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</SettingWrapper>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<SettingHeader
|
||||
title={byokT(t, 'title-beta')}
|
||||
subtitle={byokT(t, 'header')}
|
||||
/>
|
||||
<SettingWrapper>
|
||||
<div className={styles.stack}>
|
||||
{settings.hasAiPlan ? (
|
||||
<div className={styles.notice}>
|
||||
<div className={styles.title}>{byokT(t, 'notice.title')}</div>
|
||||
<div className={styles.description}>
|
||||
{byokT(t, 'notice.description')}
|
||||
</div>
|
||||
</div>
|
||||
) : null}
|
||||
|
||||
<div className={styles.panel} data-testid="workspace-byok-keys">
|
||||
<div className={styles.panelHeader}>
|
||||
<div>
|
||||
<div className={styles.title}>{byokT(t, 'keys.title')}</div>
|
||||
<div className={styles.description}>
|
||||
{byokT(t, 'keys.description')}
|
||||
</div>
|
||||
</div>
|
||||
<Button
|
||||
variant="primary"
|
||||
disabled={!canManageKeys}
|
||||
onClick={() => {
|
||||
setEditingKey(null);
|
||||
setModalOpen(true);
|
||||
}}
|
||||
>
|
||||
{byokT(t, 'action.add-key')}
|
||||
</Button>
|
||||
</div>
|
||||
{keys.length ? (
|
||||
<KeyList
|
||||
keys={keys}
|
||||
onEdit={key => {
|
||||
setEditingKey(key);
|
||||
setModalOpen(true);
|
||||
}}
|
||||
onDelete={key => {
|
||||
deleteKey(key).catch(error => {
|
||||
logByokError('Failed to delete BYOK key', error);
|
||||
notify.error({
|
||||
title: byokT(t, 'notify.delete-failed.title'),
|
||||
message: byokT(t, 'notify.operation-failed.message'),
|
||||
});
|
||||
});
|
||||
}}
|
||||
onDragStart={key => {
|
||||
setDraggingKey({ id: key.id, storage: key.storage });
|
||||
}}
|
||||
onDragEnd={() => setDraggingKey(null)}
|
||||
onDrop={key => {
|
||||
reorderKey(key).catch(error => {
|
||||
logByokError('Failed to reorder BYOK keys', error);
|
||||
notify.error({
|
||||
title: byokT(t, 'notify.reorder-failed.title'),
|
||||
message: byokT(t, 'notify.operation-failed.message'),
|
||||
});
|
||||
});
|
||||
}}
|
||||
/>
|
||||
) : (
|
||||
<div className={styles.empty} data-testid="workspace-byok-empty">
|
||||
<div className={styles.title}>{byokT(t, 'empty.title')}</div>
|
||||
<div className={styles.description}>
|
||||
{byokT(t, 'empty.description')}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<CoveragePanel keys={keys} settings={settings} />
|
||||
|
||||
<UsagePanel
|
||||
keys={keys}
|
||||
usage={usage}
|
||||
onClearAll={() => {
|
||||
clearAll().catch(error => {
|
||||
logByokError('Failed to clear BYOK keys', error);
|
||||
notify.error({
|
||||
title: byokT(t, 'notify.clear-failed.title'),
|
||||
message: byokT(t, 'notify.operation-failed.message'),
|
||||
});
|
||||
});
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</SettingWrapper>
|
||||
<AddKeyModal
|
||||
workspaceId={workspace.id}
|
||||
settings={settings}
|
||||
editingKey={editingKey}
|
||||
open={modalOpen}
|
||||
onOpenChange={open => {
|
||||
setModalOpen(open);
|
||||
if (!open) {
|
||||
setEditingKey(null);
|
||||
}
|
||||
}}
|
||||
onSaved={load}
|
||||
localKeys={localKeys}
|
||||
setLocalKeys={setLocalKeys}
|
||||
localStorageSupported={settings.localStorageSupported}
|
||||
canAddServerKey={canAddServerKey}
|
||||
canAddLocalKey={canAddLocalKey}
|
||||
gql={workspaceServer.server?.gql as GqlFn | undefined}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,92 @@
|
||||
import { DragHandle, IconButton } from '@affine/component';
|
||||
import { useI18n } from '@affine/i18n';
|
||||
import { DeleteIcon, EditIcon } from '@blocksuite/icons/rc';
|
||||
import type { DragEvent } from 'react';
|
||||
|
||||
import * as styles from './index.css';
|
||||
import {
|
||||
byokT,
|
||||
capabilityLabel,
|
||||
providerLabels,
|
||||
rowDescription,
|
||||
storageLabel,
|
||||
} from './metadata';
|
||||
import type { ByokKey } from './types';
|
||||
|
||||
export const KeyList = ({
|
||||
keys,
|
||||
onEdit,
|
||||
onDelete,
|
||||
onDragStart,
|
||||
onDragEnd,
|
||||
onDrop,
|
||||
}: {
|
||||
keys: ByokKey[];
|
||||
onEdit: (key: ByokKey) => void;
|
||||
onDelete: (key: ByokKey) => void;
|
||||
onDragStart: (key: ByokKey) => void;
|
||||
onDragEnd: () => void;
|
||||
onDrop: (key: ByokKey) => void;
|
||||
}) => {
|
||||
const t = useI18n();
|
||||
|
||||
return (
|
||||
<div className={styles.rows}>
|
||||
{keys.map(key => (
|
||||
<div
|
||||
className={`${styles.row} ${key.enabled ? '' : styles.rowDisabled}`}
|
||||
draggable
|
||||
key={`${key.storage}:${key.id}`}
|
||||
onDragStart={() => onDragStart(key)}
|
||||
onDragEnd={onDragEnd}
|
||||
onDragOver={(event: DragEvent<HTMLDivElement>) => {
|
||||
event.preventDefault();
|
||||
}}
|
||||
onDrop={event => {
|
||||
event.preventDefault();
|
||||
onDrop(key);
|
||||
}}
|
||||
>
|
||||
<div className={styles.dragHandle} title={byokT(t, 'action.reorder')}>
|
||||
<DragHandle />
|
||||
</div>
|
||||
<div className={styles.rowMain}>
|
||||
<div className={styles.rowTitle}>
|
||||
{providerLabels[key.provider]} / {key.name}
|
||||
<span className={styles.tag}>{storageLabel(t, key.storage)}</span>
|
||||
{!key.enabled ? (
|
||||
<span className={`${styles.tag} ${styles.dangerTag}`}>
|
||||
{byokT(t, 'status.disabled-after-failure')}
|
||||
</span>
|
||||
) : null}
|
||||
</div>
|
||||
<div className={styles.rowDescription}>
|
||||
{rowDescription(t, key)}
|
||||
</div>
|
||||
<div className={styles.tags}>
|
||||
{key.capabilities.map(capability => (
|
||||
<span className={styles.tag} key={capability}>
|
||||
{capabilityLabel(t, capability)}
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
<div className={styles.rowActions}>
|
||||
<IconButton
|
||||
size="20"
|
||||
title={byokT(t, 'action.edit')}
|
||||
icon={<EditIcon />}
|
||||
onClick={() => onEdit(key)}
|
||||
/>
|
||||
<IconButton
|
||||
size="20"
|
||||
title={byokT(t, 'action.delete')}
|
||||
icon={<DeleteIcon />}
|
||||
onClick={() => onDelete(key)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,108 @@
|
||||
import { apis } from '@affine/electron-api';
|
||||
import { ByokKeyStorage, ByokKeyTestStatus } from '@affine/graphql';
|
||||
|
||||
import { capabilitiesFor } from './metadata';
|
||||
import type { ByokKey, LocalByokKeyInput, LocalByokPublicKey } from './types';
|
||||
|
||||
function byokStorageApi() {
|
||||
return BUILD_CONFIG.isElectron ? apis?.byokStorage : undefined;
|
||||
}
|
||||
|
||||
export async function localByokStorageSupported() {
|
||||
const storage = byokStorageApi();
|
||||
if (!storage) {
|
||||
return false;
|
||||
}
|
||||
try {
|
||||
return await storage.isSupported();
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
function toLocalByokKey(key: LocalByokPublicKey): ByokKey {
|
||||
return {
|
||||
id: key.id,
|
||||
provider: key.provider,
|
||||
name: key.name,
|
||||
description: key.description ?? null,
|
||||
storage: ByokKeyStorage.local,
|
||||
configured: key.configured ?? true,
|
||||
enabled: key.enabled ?? true,
|
||||
endpoint: key.endpoint ?? null,
|
||||
endpointEditable: key.endpointEditable ?? false,
|
||||
sortOrder: key.sortOrder ?? 0,
|
||||
capabilities: capabilitiesFor(key.provider, ByokKeyStorage.local),
|
||||
testStatus: key.testStatus ?? ByokKeyTestStatus.passed,
|
||||
};
|
||||
}
|
||||
|
||||
export async function readLocalKeys(workspaceId: string): Promise<ByokKey[]> {
|
||||
const storage = byokStorageApi();
|
||||
if (!(await localByokStorageSupported()) || !storage) {
|
||||
return [];
|
||||
}
|
||||
try {
|
||||
const keys = (await storage.listWorkspaceKeys(
|
||||
workspaceId
|
||||
)) as LocalByokPublicKey[];
|
||||
return keys.map(toLocalByokKey);
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
export async function upsertLocalKey(
|
||||
workspaceId: string,
|
||||
key: LocalByokKeyInput
|
||||
) {
|
||||
const storage = byokStorageApi();
|
||||
if (!(await localByokStorageSupported()) || !storage) {
|
||||
return null;
|
||||
}
|
||||
try {
|
||||
return await storage.upsertWorkspaceKey(workspaceId, key);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
export async function deleteLocalKey(workspaceId: string, keyId: string) {
|
||||
const storage = byokStorageApi();
|
||||
if (!(await localByokStorageSupported()) || !storage) {
|
||||
return false;
|
||||
}
|
||||
try {
|
||||
return await storage.deleteWorkspaceKey(workspaceId, keyId);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
export async function reorderLocalKeys(workspaceId: string, ids: string[]) {
|
||||
const storage = byokStorageApi();
|
||||
if (!(await localByokStorageSupported()) || !storage) {
|
||||
return [];
|
||||
}
|
||||
try {
|
||||
const keys = (await storage.reorderWorkspaceKeys(
|
||||
workspaceId,
|
||||
ids
|
||||
)) as LocalByokPublicKey[];
|
||||
return keys.map(toLocalByokKey);
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
export async function clearLocalKeys(workspaceId: string) {
|
||||
const storage = byokStorageApi();
|
||||
if (!(await localByokStorageSupported()) || !storage) {
|
||||
return false;
|
||||
}
|
||||
try {
|
||||
return await storage.clearWorkspaceKeys(workspaceId);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user