feat(server): refactor for byok (#14911)

This commit is contained in:
DarkSky
2026-05-07 04:03:14 +08:00
committed by GitHub
parent 4e169ea5c7
commit eb9cc22502
115 changed files with 10369 additions and 1256 deletions

View File

@@ -995,6 +995,26 @@
"description": "Whether to enable the copilot plugin. <br> Document: <a href=\"https://docs.affine.pro/self-host-affine/administer/ai\" target=\"_blank\">https://docs.affine.pro/self-host-affine/administer/ai</a>\n@default false", "description": "Whether to enable the copilot plugin. <br> Document: <a href=\"https://docs.affine.pro/self-host-affine/administer/ai\" target=\"_blank\">https://docs.affine.pro/self-host-affine/administer/ai</a>\n@default false",
"default": false "default": false
}, },
"byok.enabled": {
"type": "boolean",
"description": "Whether to enable workspace BYOK.\n@default true",
"default": true
},
"byok.allowedProviders": {
"type": "array",
"description": "The allowlist for workspace BYOK providers.\n@default [\"openai\",\"anthropic\",\"gemini\",\"fal\"]",
"default": [
"openai",
"anthropic",
"gemini",
"fal"
]
},
"byok.allowCustomEndpoint": {
"type": "boolean",
"description": "Whether workspace BYOK custom endpoints are accepted.\n@default false",
"default": false
},
"providers.profiles": { "providers.profiles": {
"type": "array", "type": "array",
"description": "The profile list for copilot providers.\n@default []", "description": "The profile list for copilot providers.\n@default []",
@@ -1071,13 +1091,6 @@
}, },
"default": {} "default": {}
}, },
"providers.perplexity": {
"type": "object",
"description": "The config for the perplexity provider.\n@default {\"apiKey\":\"\"}",
"default": {
"apiKey": ""
}
},
"providers.anthropic": { "providers.anthropic": {
"type": "object", "type": "object",
"description": "The config for the anthropic provider.\n@default {\"apiKey\":\"\",\"baseURL\":\"https://api.anthropic.com/v1\"}", "description": "The config for the anthropic provider.\n@default {\"apiKey\":\"\",\"baseURL\":\"https://api.anthropic.com/v1\"}",
@@ -1121,11 +1134,6 @@
}, },
"default": {} "default": {}
}, },
"providers.morph": {
"type": "object",
"description": "The config for the morph provider.\n@default {}",
"default": {}
},
"unsplash": { "unsplash": {
"type": "object", "type": "object",
"description": "The config for the unsplash key.\n@default {\"key\":\"\"}", "description": "The config for the unsplash key.\n@default {\"key\":\"\"}",

View File

@@ -364,7 +364,7 @@ export interface ModelConditionsContract {
} }
export interface ModelRegistryMatchRequest { export interface ModelRegistryMatchRequest {
backendKind: 'openai_chat' | 'openai_responses' | 'anthropic' | 'cloudflare_workers_ai' | 'gemini_api' | 'gemini_vertex' | 'fal' | 'perplexity' | 'anthropic_vertex' | 'morph' backendKind: 'openai_chat' | 'openai_responses' | 'anthropic' | 'cloudflare_workers_ai' | 'gemini_api' | 'gemini_vertex' | 'fal' | 'anthropic_vertex'
cond: ModelConditionsContract cond: ModelConditionsContract
} }
@@ -373,7 +373,7 @@ export interface ModelRegistryMatchResponse {
} }
export interface ModelRegistryResolveRequest { export interface ModelRegistryResolveRequest {
backendKind?: 'openai_chat' | 'openai_responses' | 'anthropic' | 'cloudflare_workers_ai' | 'gemini_api' | 'gemini_vertex' | 'fal' | 'perplexity' | 'anthropic_vertex' | 'morph' backendKind?: 'openai_chat' | 'openai_responses' | 'anthropic' | 'cloudflare_workers_ai' | 'gemini_api' | 'gemini_vertex' | 'fal' | 'anthropic_vertex'
modelId: string modelId: string
} }
@@ -388,7 +388,7 @@ export interface ModelRegistryRouteContract {
} }
export interface ModelRegistryVariantContract { export interface ModelRegistryVariantContract {
backendKind: 'openai_chat' | 'openai_responses' | 'anthropic' | 'cloudflare_workers_ai' | 'gemini_api' | 'gemini_vertex' | 'fal' | 'perplexity' | 'anthropic_vertex' | 'morph' backendKind: 'openai_chat' | 'openai_responses' | 'anthropic' | 'cloudflare_workers_ai' | 'gemini_api' | 'gemini_vertex' | 'fal' | 'anthropic_vertex'
canonicalKey: string canonicalKey: string
rawModelId: string rawModelId: string
displayName?: string displayName?: string

View File

@@ -686,17 +686,6 @@
} }
] ]
}, },
{
"name": "Apply Updates",
"action": "Apply Updates",
"model": "claude-sonnet-4-5@20250929",
"messages": [
{
"role": "user",
"template": "\nYou are a Markdown document update engine.\n\nYou will be given:\n\n1. content: The original Markdown document\n - The content is structured into blocks.\n - Each block starts with a comment like <!-- block_id=... flavour=... --> and contains the block's content.\n - The content is {{content}}\n\n2. op: A description of the edit intention\n - This describes the semantic meaning of the edit, such as \"Bold the first paragraph\".\n - The op is {{op}}\n\n3. updates: A Markdown snippet\n - The updates is {{updates}}\n - This represents the block-level changes to apply to the original Markdown.\n - The update may:\n - **Replace** an existing block (same block_id, new content)\n - **Delete** block(s) using <!-- delete block BLOCK_ID -->\n - **Insert** new block(s) with a new unique block_id\n - When performing deletions, the update will include **surrounding context blocks** (or use <!-- existing blocks -->) to help you determine where and what to delete.\n\nYour task:\n- Apply the update in <updates> to the document in <code>, following the intent described in <op>.\n- Preserve all block_id and flavour comments.\n- Maintain the original block order unless the update clearly appends new blocks.\n- Do not remove or alter unrelated blocks.\n- Output only the fully updated Markdown content. Do not wrap the content in ```markdown.\n\n---\n\n✍ Examples\n\n✅ Replacement (modifying an existing block)\n\n<code>\n<!-- block_id=101 flavour=paragraph -->\n## Introduction\n\n<!-- block_id=102 flavour=paragraph -->\nThis document provides an overview of the system architecture and its components.\n</code>\n\n<op>\nMake the introduction more formal.\n</op>\n\n<updates>\n<!-- block_id=102 flavour=paragraph -->\nThis document outlines the architectural design and individual components of the system in detail.\n</updates>\n\nExpected Output:\n<!-- block_id=101 flavour=paragraph -->\n## Introduction\n\n<!-- block_id=102 flavour=paragraph -->\nThis document outlines the architectural design and individual components of the system in detail.\n\n---\n\n Insertion (adding new content)\n\n<code>\n<!-- block_id=201 flavour=paragraph -->\n# Project Summary\n\n<!-- block_id=202 flavour=paragraph -->\nThis project aims to build a collaborative text editing tool.\n</code>\n\n<op>\nAdd a disclaimer section at the end.\n</op>\n\n<updates>\n<!-- block_id=new-301 flavour=paragraph -->\n## Disclaimer\n\n<!-- block_id=new-302 flavour=paragraph -->\nThis document is subject to change. Do not distribute externally.\n</updates>\n\nExpected Output:\n<!-- block_id=201 flavour=paragraph -->\n# Project Summary\n\n<!-- block_id=202 flavour=paragraph -->\nThis project aims to build a collaborative text editing tool.\n\n<!-- block_id=new-301 flavour=paragraph -->\n## Disclaimer\n\n<!-- block_id=new-302 flavour=paragraph -->\nThis document is subject to change. Do not distribute externally.\n\n---\n\n❌ Deletion (removing blocks)\n\n<code>\n<!-- block_id=401 flavour=paragraph -->\n## Author\n\n<!-- block_id=402 flavour=paragraph -->\nWritten by the AI team at OpenResearch.\n\n<!-- block_id=403 flavour=paragraph -->\n## Experimental Section\n\n<!-- block_id=404 flavour=paragraph -->\nThe following section is still under development and may change without notice.\n\n<!-- block_id=405 flavour=paragraph -->\n## License\n\n<!-- block_id=406 flavour=paragraph -->\nThis document is licensed under CC BY-NC 4.0.\n</code>\n\n<op>\nRemove the experimental section.\n</op>\n\n<updates>\n<!-- delete block_id=403 -->\n<!-- delete block_id=404 -->\n</updates>\n\nExpected Output:\n<!-- block_id=401 flavour=paragraph -->\n## Author\n\n<!-- block_id=402 flavour=paragraph -->\nWritten by the AI team at OpenResearch.\n\n<!-- block_id=405 flavour=paragraph -->\n## License\n\n<!-- block_id=406 flavour=paragraph -->\nThis document is licensed under CC BY-NC 4.0.\n\n---\n\nNow apply the `updates` to the `content`, following the intent in `op`, and return the updated Markdown.\n"
}
]
},
{ {
"name": "Code Artifact", "name": "Code Artifact",
"model": "claude-sonnet-4-5@20250929", "model": "claude-sonnet-4-5@20250929",

View File

@@ -319,7 +319,7 @@ pub struct RequestedModelMatchResponse {
pub struct ModelRegistryResolveRequest { pub struct ModelRegistryResolveRequest {
#[napi( #[napi(
ts_type = "'openai_chat' | 'openai_responses' | 'anthropic' | 'cloudflare_workers_ai' | 'gemini_api' | \ ts_type = "'openai_chat' | 'openai_responses' | 'anthropic' | 'cloudflare_workers_ai' | 'gemini_api' | \
'gemini_vertex' | 'fal' | 'perplexity' | 'anthropic_vertex' | 'morph'" 'gemini_vertex' | 'fal' | 'anthropic_vertex'"
)] )]
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub backend_kind: Option<String>, pub backend_kind: Option<String>,
@@ -333,7 +333,7 @@ pub struct ModelRegistryResolveRequest {
pub struct ModelRegistryMatchRequest { pub struct ModelRegistryMatchRequest {
#[napi( #[napi(
ts_type = "'openai_chat' | 'openai_responses' | 'anthropic' | 'cloudflare_workers_ai' | 'gemini_api' | \ ts_type = "'openai_chat' | 'openai_responses' | 'anthropic' | 'cloudflare_workers_ai' | 'gemini_api' | \
'gemini_vertex' | 'fal' | 'perplexity' | 'anthropic_vertex' | 'morph'" 'gemini_vertex' | 'fal' | 'anthropic_vertex'"
)] )]
pub backend_kind: String, pub backend_kind: String,
pub cond: ModelConditionsContract, pub cond: ModelConditionsContract,
@@ -346,7 +346,7 @@ pub struct ModelRegistryMatchRequest {
pub struct ModelRegistryVariantContract { pub struct ModelRegistryVariantContract {
#[napi( #[napi(
ts_type = "'openai_chat' | 'openai_responses' | 'anthropic' | 'cloudflare_workers_ai' | 'gemini_api' | \ ts_type = "'openai_chat' | 'openai_responses' | 'anthropic' | 'cloudflare_workers_ai' | 'gemini_api' | \
'gemini_vertex' | 'fal' | 'perplexity' | 'anthropic_vertex' | 'morph'" 'gemini_vertex' | 'fal' | 'anthropic_vertex'"
)] )]
pub backend_kind: String, pub backend_kind: String,
pub canonical_key: String, pub canonical_key: String,

View File

@@ -7,7 +7,7 @@ pub(crate) use error::{
STREAM_ABORTED_REASON, STREAM_CALLBACK_DISPATCH_FAILED_REASON, STREAM_END_MARKER, callback_dispatch_failed_reason, STREAM_ABORTED_REASON, STREAM_CALLBACK_DISPATCH_FAILED_REASON, STREAM_END_MARKER, callback_dispatch_failed_reason,
invalid_arg, invalid_arg,
}; };
pub(crate) use stream::emit_error_event; pub(crate) use stream::{emit_error_event, emit_provider_selected_event};
pub use stream::{ pub use stream::{
llm_dispatch_prepared_stream, llm_dispatch_tool_loop_stream, llm_dispatch_tool_loop_stream_prepared, llm_dispatch_prepared_stream, llm_dispatch_tool_loop_stream, llm_dispatch_tool_loop_stream_prepared,
llm_dispatch_tool_loop_stream_routed, llm_dispatch_tool_loop_stream_routed,

View File

@@ -106,14 +106,18 @@ fn spawn_prepared_stream(
if reason.starts_with(STREAM_CALLBACK_DISPATCH_FAILED_REASON) if reason.starts_with(STREAM_CALLBACK_DISPATCH_FAILED_REASON)
); );
if let Err(error) = result if let Err(error) = &result
&& !aborted_in_worker.load(Ordering::Relaxed) && !aborted_in_worker.load(Ordering::Relaxed)
&& !callback_dispatch_failed && !callback_dispatch_failed
&& !is_abort_error(&error) && !is_abort_error(error)
{ {
emit_error_event(&callback, error.to_string(), "dispatch_error"); emit_error_event(&callback, error.to_string(), "dispatch_error");
} }
if let Ok(provider_id) = result {
emit_provider_selected_event(&callback, provider_id);
}
if !callback_dispatch_failed { if !callback_dispatch_failed {
let _ = callback.call( let _ = callback.call(
Ok(STREAM_END_MARKER.to_string()), Ok(STREAM_END_MARKER.to_string()),
@@ -129,7 +133,7 @@ fn dispatch_prepared_stream_with_fallback(
routes: &[PreparedDispatchRoute], routes: &[PreparedDispatchRoute],
callback: &ThreadsafeFunction<String, ()>, callback: &ThreadsafeFunction<String, ()>,
aborted: &AtomicBool, aborted: &AtomicBool,
) -> std::result::Result<(), BackendError> { ) -> std::result::Result<String, BackendError> {
dispatch_prepared_stream_with_fallback_using_client(&DefaultHttpClient::default(), routes, aborted, |event| { dispatch_prepared_stream_with_fallback_using_client(&DefaultHttpClient::default(), routes, aborted, |event| {
emit_stream_event(callback, event) emit_stream_event(callback, event)
}) })
@@ -140,7 +144,7 @@ fn dispatch_prepared_stream_with_fallback_using_client<F>(
routes: &[PreparedDispatchRoute], routes: &[PreparedDispatchRoute],
aborted: &AtomicBool, aborted: &AtomicBool,
mut emit_event: F, mut emit_event: F,
) -> std::result::Result<(), BackendError> ) -> std::result::Result<String, BackendError>
where where
F: FnMut(&StreamEvent) -> Status, F: FnMut(&StreamEvent) -> Status,
{ {
@@ -154,7 +158,7 @@ where
.collect::<std::result::Result<Vec<_>, BackendError>>()?; .collect::<std::result::Result<Vec<_>, BackendError>>()?;
let mut callback_dispatch_failed = false; let mut callback_dispatch_failed = false;
dispatch_prepared_stream_with_pipeline( let provider_id = dispatch_prepared_stream_with_pipeline(
client, client,
&mut adapter_routes, &mut adapter_routes,
|| aborted.load(Ordering::Relaxed), || aborted.load(Ordering::Relaxed),
@@ -174,7 +178,7 @@ where
"{STREAM_CALLBACK_DISPATCH_FAILED_REASON}:unknown" "{STREAM_CALLBACK_DISPATCH_FAILED_REASON}:unknown"
))) )))
} else { } else {
Ok(()) Ok(provider_id)
} }
} }
@@ -195,6 +199,16 @@ pub(crate) fn emit_error_event(callback: &ThreadsafeFunction<String, ()>, messag
let _ = callback.call(Ok(error_event), ThreadsafeFunctionCallMode::NonBlocking); let _ = callback.call(Ok(error_event), ThreadsafeFunctionCallMode::NonBlocking);
} }
pub(crate) fn emit_provider_selected_event(callback: &ThreadsafeFunction<String, ()>, provider_id: String) {
let event = serde_json::json!({
"type": "provider_selected",
"provider_id": provider_id,
})
.to_string();
let _ = callback.call(Ok(event), ThreadsafeFunctionCallMode::NonBlocking);
}
fn emit_stream_event(callback: &ThreadsafeFunction<String, ()>, event: &StreamEvent) -> Status { fn emit_stream_event(callback: &ThreadsafeFunction<String, ()>, event: &StreamEvent) -> Status {
let value = serde_json::to_string(event).unwrap_or_else(|error| { let value = serde_json::to_string(event).unwrap_or_else(|error| {
serde_json::json!({ serde_json::json!({

View File

@@ -14,7 +14,10 @@ use napi::{
threadsafe_function::{ThreadsafeFunction, ThreadsafeFunctionCallMode}, threadsafe_function::{ThreadsafeFunction, ThreadsafeFunctionCallMode},
}; };
use super::callback::{NapiEventSink, NapiToolExecutor, emit_tool_loop_event}; use super::{
super::emit_provider_selected_event,
callback::{NapiEventSink, NapiToolExecutor, emit_tool_loop_event},
};
use crate::llm::{ use crate::llm::{
LlmDispatchPayload, LlmMiddlewarePayload, LlmStreamHandle, STREAM_ABORTED_REASON, LlmDispatchPayload, LlmMiddlewarePayload, LlmStreamHandle, STREAM_ABORTED_REASON,
STREAM_CALLBACK_DISPATCH_FAILED_REASON, STREAM_END_MARKER, StreamPipeline, apply_request_middlewares, STREAM_CALLBACK_DISPATCH_FAILED_REASON, STREAM_END_MARKER, StreamPipeline, apply_request_middlewares,
@@ -39,11 +42,13 @@ fn dispatch_prepared_round_with_fallback(
}) })
.collect::<std::result::Result<Vec<_>, BackendError>>()?; .collect::<std::result::Result<Vec<_>, BackendError>>()?;
run_prepared_stream_round_with_fallback( let mut selected_provider_id: Option<String> = None;
let outcome = run_prepared_stream_round_with_fallback(
&mut pipelines, &mut pipelines,
|on_event| { |on_event| {
let (selected_index, _) = let (selected_index, provider_id) =
dispatch_prepared_stream_with_fallback_index(&DefaultHttpClient::default(), &adapter_routes, on_event)?; dispatch_prepared_stream_with_fallback_index(&DefaultHttpClient::default(), &adapter_routes, on_event)?;
selected_provider_id = Some(provider_id);
Ok(selected_index) Ok(selected_index)
}, },
|| aborted.load(Ordering::Relaxed), || aborted.load(Ordering::Relaxed),
@@ -53,7 +58,11 @@ fn dispatch_prepared_round_with_fallback(
emitted.store(true, Ordering::Relaxed); emitted.store(true, Ordering::Relaxed);
emit_tool_loop_event(callback, loop_event) emit_tool_loop_event(callback, loop_event)
}, },
) )?;
if let Some(provider_id) = selected_provider_id {
emit_provider_selected_event(callback, provider_id);
}
Ok(outcome)
} }
fn prepare_tool_loop_route( fn prepare_tool_loop_route(

View File

@@ -0,0 +1,73 @@
-- CreateTable
CREATE TABLE "ai_workspace_byok_configs" (
"id" VARCHAR NOT NULL,
"workspace_id" VARCHAR NOT NULL,
"provider" VARCHAR NOT NULL,
"name" VARCHAR NOT NULL,
"description" VARCHAR,
"encrypted_api_key" TEXT NOT NULL,
"endpoint" TEXT,
"sort_order" INTEGER NOT NULL DEFAULT 0,
"enabled" BOOLEAN NOT NULL DEFAULT true,
"disabled_reason" VARCHAR,
"last_validated_at" TIMESTAMPTZ(3),
"last_validation_error" TEXT,
"last_used_at" TIMESTAMPTZ(3),
"last_error_at" TIMESTAMPTZ(3),
"last_error" TEXT,
"created_by" VARCHAR,
"updated_by" VARCHAR,
"created_at" TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMPTZ(3) NOT NULL,
CONSTRAINT "ai_workspace_byok_configs_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "ai_usage_events" (
"id" VARCHAR NOT NULL,
"workspace_id" VARCHAR NOT NULL,
"user_id" VARCHAR,
"provider" VARCHAR NOT NULL,
"provider_source" VARCHAR NOT NULL,
"feature_kind" VARCHAR NOT NULL,
"model" VARCHAR,
"session_id" VARCHAR,
"task_id" VARCHAR,
"action_id" VARCHAR,
"billing_unit_id" VARCHAR,
"prompt_tokens" INTEGER NOT NULL DEFAULT 0,
"completion_tokens" INTEGER NOT NULL DEFAULT 0,
"total_tokens" INTEGER NOT NULL DEFAULT 0,
"cached_tokens" INTEGER NOT NULL DEFAULT 0,
"created_at" TIMESTAMPTZ(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT "ai_usage_events_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE INDEX "ai_workspace_byok_configs_workspace_id_idx" ON "ai_workspace_byok_configs"("workspace_id");
-- CreateIndex
CREATE INDEX "ai_workspace_byok_configs_workspace_id_provider_enabled_sor_idx" ON "ai_workspace_byok_configs"("workspace_id", "provider", "enabled", "sort_order");
-- CreateIndex
CREATE UNIQUE INDEX "ai_workspace_byok_configs_workspace_id_provider_name_key" ON "ai_workspace_byok_configs"("workspace_id", "provider", "name");
-- CreateIndex
CREATE INDEX "ai_usage_events_workspace_id_created_at_idx" ON "ai_usage_events"("workspace_id", "created_at");
-- CreateIndex
CREATE INDEX "ai_usage_events_workspace_id_provider_source_created_at_idx" ON "ai_usage_events"("workspace_id", "provider_source", "created_at");
-- CreateIndex
CREATE INDEX "ai_usage_events_feature_kind_created_at_idx" ON "ai_usage_events"("feature_kind", "created_at");
-- CreateIndex
CREATE INDEX "ai_usage_events_quota_exempt_idx" ON "ai_usage_events"("user_id", "provider_source", "feature_kind", "billing_unit_id");
-- AddForeignKey
ALTER TABLE "ai_workspace_byok_configs" ADD CONSTRAINT "ai_workspace_byok_configs_workspace_id_fkey" FOREIGN KEY ("workspace_id") REFERENCES "workspaces"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "ai_usage_events" ADD CONSTRAINT "ai_usage_events_workspace_id_fkey" FOREIGN KEY ("workspace_id") REFERENCES "workspaces"("id") ON DELETE CASCADE ON UPDATE CASCADE;

View File

@@ -147,6 +147,8 @@ model Workspace {
blobs Blob[] blobs Blob[]
ignoredDocs AiWorkspaceIgnoredDocs[] ignoredDocs AiWorkspaceIgnoredDocs[]
embedFiles AiWorkspaceFiles[] embedFiles AiWorkspaceFiles[]
byokConfigs AiWorkspaceByokConfig[]
aiUsageEvents AiUsageEvent[]
comments Comment[] comments Comment[]
commentAttachments CommentAttachment[] commentAttachments CommentAttachment[]
workspaceCalendars WorkspaceCalendar[] workspaceCalendars WorkspaceCalendar[]
@@ -558,27 +560,27 @@ model AiSession {
} }
model AiActionRun { model AiActionRun {
id String @id @default(uuid()) @db.VarChar id String @id @default(uuid()) @db.VarChar
userId String @map("user_id") @db.VarChar userId String @map("user_id") @db.VarChar
workspaceId String @map("workspace_id") @db.VarChar workspaceId String @map("workspace_id") @db.VarChar
docId String? @map("doc_id") @db.VarChar docId String? @map("doc_id") @db.VarChar
sessionId String? @map("session_id") @db.VarChar sessionId String? @map("session_id") @db.VarChar
userMessageId String? @map("user_message_id") @db.VarChar userMessageId String? @map("user_message_id") @db.VarChar
compatSubmissionId String? @map("compat_submission_id") @db.VarChar compatSubmissionId String? @map("compat_submission_id") @db.VarChar
assistantMessageId String? @map("assistant_message_id") @db.VarChar assistantMessageId String? @map("assistant_message_id") @db.VarChar
actionId String @map("action_id") @db.VarChar actionId String @map("action_id") @db.VarChar
actionVersion String @map("action_version") @db.VarChar actionVersion String @map("action_version") @db.VarChar
status String @db.VarChar status String @db.VarChar
attempt Int @default(1) attempt Int @default(1)
retryOf String? @map("retry_of") @db.VarChar retryOf String? @map("retry_of") @db.VarChar
inputSnapshot Json? @map("input_snapshot") @db.Json inputSnapshot Json? @map("input_snapshot") @db.Json
result Json? @db.Json result Json? @db.Json
artifacts Json? @db.Json artifacts Json? @db.Json
resultSummary String? @map("result_summary") @db.Text resultSummary String? @map("result_summary") @db.Text
errorCode String? @map("error_code") @db.VarChar errorCode String? @map("error_code") @db.VarChar
trace Json? @db.Json trace Json? @db.Json
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3) createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3)
updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(3) updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(3)
session AiSession? @relation(fields: [sessionId], references: [id], onDelete: SetNull) session AiSession? @relation(fields: [sessionId], references: [id], onDelete: SetNull)
@@ -732,6 +734,62 @@ model AiWorkspaceBlobEmbedding {
@@map("ai_workspace_blob_embeddings") @@map("ai_workspace_blob_embeddings")
} }
model AiWorkspaceByokConfig {
id String @id @default(uuid()) @db.VarChar
workspaceId String @map("workspace_id") @db.VarChar
provider String @db.VarChar
name String @db.VarChar
description String? @db.VarChar
encryptedApiKey String @map("encrypted_api_key") @db.Text
endpoint String? @db.Text
sortOrder Int @default(0) @map("sort_order")
enabled Boolean @default(true)
disabledReason String? @map("disabled_reason") @db.VarChar
lastValidatedAt DateTime? @map("last_validated_at") @db.Timestamptz(3)
lastValidationError String? @map("last_validation_error") @db.Text
lastUsedAt DateTime? @map("last_used_at") @db.Timestamptz(3)
lastErrorAt DateTime? @map("last_error_at") @db.Timestamptz(3)
lastError String? @map("last_error") @db.Text
createdBy String? @map("created_by") @db.VarChar
updatedBy String? @map("updated_by") @db.VarChar
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3)
updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(3)
workspace Workspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade)
@@unique([workspaceId, provider, name])
@@index([workspaceId])
@@index([workspaceId, provider, enabled, sortOrder])
@@map("ai_workspace_byok_configs")
}
model AiUsageEvent {
id String @id @default(uuid()) @db.VarChar
workspaceId String @map("workspace_id") @db.VarChar
userId String? @map("user_id") @db.VarChar
provider String @db.VarChar
providerSource String @map("provider_source") @db.VarChar
featureKind String @map("feature_kind") @db.VarChar
model String? @db.VarChar
sessionId String? @map("session_id") @db.VarChar
taskId String? @map("task_id") @db.VarChar
actionId String? @map("action_id") @db.VarChar
billingUnitId String? @map("billing_unit_id") @db.VarChar
promptTokens Int @default(0) @map("prompt_tokens")
completionTokens Int @default(0) @map("completion_tokens")
totalTokens Int @default(0) @map("total_tokens")
cachedTokens Int @default(0) @map("cached_tokens")
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(3)
workspace Workspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade)
@@index([workspaceId, createdAt])
@@index([workspaceId, providerSource, createdAt])
@@index([featureKind, createdAt])
@@index([userId, providerSource, featureKind, billingUnitId], map: "ai_usage_events_quota_exempt_idx")
@@map("ai_usage_events")
}
enum AiJobStatus { enum AiJobStatus {
pending pending
running running

View File

@@ -526,17 +526,6 @@ Generated by [AVA](https://avajs.dev).
remoteAttachmentRequests: [], remoteAttachmentRequests: [],
} }
## PerplexityProvider should ignore attachments during text model matching
> Snapshot 1
[
{
text: 'summarize this',
type: 'text',
},
]
## GeminiVertexProvider should prefetch bearer token for native config ## GeminiVertexProvider should prefetch bearer token for native config
> Snapshot 1 > Snapshot 1

File diff suppressed because it is too large Load Diff

View File

@@ -771,7 +771,7 @@ function actionRunRecord(
}; };
} }
function installActionSessionMock( async function installActionSessionMock(
t: ExecutionContext<Tester>, t: ExecutionContext<Tester>,
{ {
actionId, actionId,
@@ -786,8 +786,12 @@ function installActionSessionMock(
const { models, session } = t.context; const { models, session } = t.context;
const sandbox = Sinon.createSandbox(); const sandbox = Sinon.createSandbox();
const sessionId = `copilot-provider-action-${actionId}-${randomUUID()}`; const sessionId = `copilot-provider-action-${actionId}-${randomUUID()}`;
const userId = `copilot-provider-user-${randomUUID()}`; const user = await models.user.create({
const workspaceId = `copilot-provider-action-${actionId}`; email: `copilot-provider-user-${randomUUID()}@affine.test`,
});
const userId = user.id;
const workspace = await models.workspace.create(userId);
const workspaceId = workspace.id;
const docId = `copilot-provider-action-${actionId}-doc`; const docId = `copilot-provider-action-${actionId}-doc`;
const savedTurns: Array<{ role: string }> = []; const savedTurns: Array<{ role: string }> = [];
const userTurn = { const userTurn = {
@@ -904,7 +908,11 @@ for (const { actionId, content, verifier } of actionRecipeCases) {
} }
const { sandbox, sessionId, userId, savedTurns } = const { sandbox, sessionId, userId, savedTurns } =
installActionSessionMock(t, { actionId, actionPrompt, content }); await installActionSessionMock(t, {
actionId,
actionPrompt,
content,
});
let result = ''; let result = '';
try { try {
@@ -976,8 +984,10 @@ for (const testCase of TRANSCRIPT_AUDIO_CASES) {
runIfCopilotConfigured, runIfCopilotConfigured,
async t => { async t => {
const { models, transcript } = t.context; const { models, transcript } = t.context;
const userId = `copilot-provider-transcript-user-${randomUUID()}`; const user = await models.user.create({
const workspaceId = `copilot-provider-transcript-workspace-${randomUUID()}`; email: `copilot-provider-transcript-${randomUUID()}@affine.pro`,
});
const workspace = await models.workspace.create(user.id);
const blobId = `copilot-provider-transcript-blob-${randomUUID()}`; const blobId = `copilot-provider-transcript-blob-${randomUUID()}`;
const payload = TranscriptPayloadSchema.parse({ const payload = TranscriptPayloadSchema.parse({
sourceAudio: { blobId, mimeType: testCase.mimeType }, sourceAudio: { blobId, mimeType: testCase.mimeType },
@@ -990,8 +1000,8 @@ for (const testCase of TRANSCRIPT_AUDIO_CASES) {
], ],
}); });
const task = await models.copilotTranscriptTask.create({ const task = await models.copilotTranscriptTask.create({
userId, userId: user.id,
workspaceId, workspaceId: workspace.id,
blobId, blobId,
strategy: 'gemini', strategy: 'gemini',
recipeId: 'transcript.audio.gemini', recipeId: 'transcript.audio.gemini',

View File

@@ -139,9 +139,6 @@ test.before(async t => {
fal: { fal: {
apiKey: process.env.COPILOT_FAL_API_KEY ?? '1', apiKey: process.env.COPILOT_FAL_API_KEY ?? '1',
}, },
perplexity: {
apiKey: process.env.COPILOT_PERPLEXITY_API_KEY ?? '1',
},
anthropic: { anthropic: {
apiKey: process.env.COPILOT_ANTHROPIC_API_KEY ?? '1', apiKey: process.env.COPILOT_ANTHROPIC_API_KEY ?? '1',
}, },

View File

@@ -1,11 +1,15 @@
import test from 'ava'; import test from 'ava';
import Sinon from 'sinon'; import Sinon from 'sinon';
import type { Models } from '../../models'; import { type Models } from '../../models';
import { CopilotAccessPolicy } from '../../plugins/copilot/access';
import type { ByokFeatureKind } from '../../plugins/copilot/byok/types';
import { HistoryAttachmentUrlProjector } from '../../plugins/copilot/compat/history-attachment-url-projector'; import { HistoryAttachmentUrlProjector } from '../../plugins/copilot/compat/history-attachment-url-projector';
import { CompatHistoryProjector } from '../../plugins/copilot/compat/history-projector'; import { CompatHistoryProjector } from '../../plugins/copilot/compat/history-projector';
import { HistoryPromptPreloadProjector } from '../../plugins/copilot/compat/history-prompt-preload-projector'; import { HistoryPromptPreloadProjector } from '../../plugins/copilot/compat/history-prompt-preload-projector';
import { HistoryVisibilityPolicy } from '../../plugins/copilot/compat/history-visibility-policy'; import { HistoryVisibilityPolicy } from '../../plugins/copilot/compat/history-visibility-policy';
import { ConversationPolicy } from '../../plugins/copilot/conversation/policy';
import type { Turn } from '../../plugins/copilot/core';
import { CopilotEmbeddingClientService } from '../../plugins/copilot/embedding/client'; import { CopilotEmbeddingClientService } from '../../plugins/copilot/embedding/client';
import { CopilotProviderType } from '../../plugins/copilot/providers/types'; import { CopilotProviderType } from '../../plugins/copilot/providers/types';
import { import {
@@ -29,9 +33,11 @@ import {
AttachmentMaterializer, AttachmentMaterializer,
resolveAttachmentFetchUrl, resolveAttachmentFetchUrl,
} from '../../plugins/copilot/runtime/hosts/attachment-materializer'; } from '../../plugins/copilot/runtime/hosts/attachment-materializer';
import { ConversationHost } from '../../plugins/copilot/runtime/hosts/conversation-host';
import { ImageResultHost } from '../../plugins/copilot/runtime/hosts/image-result-host'; import { ImageResultHost } from '../../plugins/copilot/runtime/hosts/image-result-host';
import { ResponsePostprocessor } from '../../plugins/copilot/runtime/hosts/response-postprocessor'; import { ResponsePostprocessor } from '../../plugins/copilot/runtime/hosts/response-postprocessor';
import { TurnPersistence } from '../../plugins/copilot/runtime/hosts/turn-persistence'; import { TurnPersistence } from '../../plugins/copilot/runtime/hosts/turn-persistence';
import { ToolRuntime } from '../../plugins/copilot/runtime/tool-runtime';
function stubTurnPersistence( function stubTurnPersistence(
persistProjectedResult: Sinon.SinonStub = Sinon.stub().resolves(null) persistProjectedResult: Sinon.SinonStub = Sinon.stub().resolves(null)
@@ -41,6 +47,367 @@ function stubTurnPersistence(
} as unknown as TurnPersistence; } as unknown as TurnPersistence;
} }
function stubConversationSession(latestUserTurn?: unknown) {
return {
config: {
sessionId: 'session-1',
userId: 'user-1',
workspaceId: 'workspace-1',
},
model: 'gpt-4o-mini',
stashTurns: latestUserTurn ? [latestUserTurn] : [],
latestUserTurn,
revertLatestMessage: Sinon.stub(),
};
}
test('ConversationPolicy should treat zero quota limit as exhausted', async t => {
const policy = new ConversationPolicy(
{
userFeature: { has: Sinon.stub().resolves(false) },
copilotSession: { countUserMessages: Sinon.stub().resolves(0) },
} as any,
{
getUserQuota: Sinon.stub().resolves({ copilotActionLimit: 0 }),
} as any
);
t.false(await policy.hasQuota('user-1'));
await t.throwsAsync(policy.checkQuota('user-1'));
});
type TurnRouteAccessCase = {
name: string;
profiles: Array<{ id: string }>;
featureKind?: 'embedding' | 'rerank' | 'workspace_indexing';
byokLeaseId?: string;
quotaBackedRoutesAllowed?: boolean;
expectedQuotaCalls: number;
expectedError?: string;
expectedQuotaBackedRoutesAllowed?: boolean;
};
const turnRouteAccessCases: TurnRouteAccessCase[] = [
{
name: 'checks quota when BYOK does not cover the route',
profiles: [],
expectedQuotaCalls: 1,
expectedError: 'quota exceeded',
},
{
name: 'skips quota when BYOK covers the route',
profiles: [{ id: 'profile-1' }],
byokLeaseId: 'lease-1',
expectedQuotaCalls: 0,
expectedQuotaBackedRoutesAllowed: undefined,
},
{
name: 'preserves explicit quota-backed route disable override',
profiles: [],
quotaBackedRoutesAllowed: false,
expectedQuotaCalls: 0,
expectedQuotaBackedRoutesAllowed: false,
},
{
name: 'does not check user quota for unmetered service features',
profiles: [],
featureKind: 'rerank',
expectedQuotaCalls: 0,
expectedQuotaBackedRoutesAllowed: true,
},
];
for (const matrixCase of turnRouteAccessCases) {
test(`CopilotAccessPolicy resolve turn route access: ${matrixCase.name}`, async t => {
const checkQuota = Sinon.stub().rejects(new Error('quota exceeded'));
const getProfiles = Sinon.stub().resolves(matrixCase.profiles);
const access = new CopilotAccessPolicy(
{ checkQuota } as any,
{ getProfiles } as any
);
const promise = access.resolveTurnRouteAccess({
userId: 'user-1',
workspaceId: 'workspace-1',
byokLeaseId: matrixCase.byokLeaseId,
featureKind: matrixCase.featureKind,
quotaBackedRoutesAllowed: matrixCase.quotaBackedRoutesAllowed,
});
if (matrixCase.expectedError) {
await t.throwsAsync(promise, { message: matrixCase.expectedError });
} else {
const routeAccess = await promise;
t.is(
routeAccess.quotaBackedRoutesAllowed,
matrixCase.expectedQuotaBackedRoutesAllowed
);
}
t.is(checkQuota.callCount, matrixCase.expectedQuotaCalls);
if (matrixCase.expectedQuotaCalls) {
Sinon.assert.calledWithExactly(checkQuota, 'user-1');
}
if (matrixCase.byokLeaseId) {
Sinon.assert.calledWithMatch(getProfiles, {
byokLeaseId: matrixCase.byokLeaseId,
});
}
});
}
type ByokCoverageCase = {
featureKind?: ByokFeatureKind;
expected: { local: boolean; server: boolean };
};
const byokCoverageCases: ByokCoverageCase[] = [
{ featureKind: 'chat', expected: { local: true, server: true } },
{ featureKind: 'action', expected: { local: true, server: true } },
{ featureKind: 'image', expected: { local: true, server: true } },
{ featureKind: 'transcript', expected: { local: false, server: true } },
{ featureKind: 'embedding', expected: { local: false, server: true } },
{
featureKind: 'workspace_indexing',
expected: { local: false, server: true },
},
{ featureKind: 'rerank', expected: { local: false, server: true } },
{ expected: { local: true, server: true } },
];
for (const matrixCase of byokCoverageCases) {
test(`CopilotAccessPolicy should resolve BYOK coverage for ${matrixCase.featureKind ?? 'default'}`, async t => {
const getProfiles = Sinon.stub().resolves([]);
const access = new CopilotAccessPolicy(
{ hasQuota: Sinon.stub().resolves(true) } as any,
{ getProfiles } as any
);
await access.getByokProfiles({
userId: 'user-1',
workspaceId: 'workspace-1',
featureKind: matrixCase.featureKind,
});
t.like(getProfiles.firstCall.args[0], {
userId: 'user-1',
workspaceId: 'workspace-1',
});
t.is(getProfiles.firstCall.args[0].featureKind, matrixCase.featureKind);
t.deepEqual(getProfiles.firstCall.args[1], matrixCase.expected);
});
}
test('CopilotAccessPolicy assertQuotaOrByok should honor quota-backed route disable', async t => {
const checkQuota = Sinon.stub().resolves(undefined);
const access = new CopilotAccessPolicy(
{ checkQuota } as any,
{ getProfiles: Sinon.stub().resolves([]) } as any
);
await t.throwsAsync(
access.assertQuotaOrByok({
userId: 'user-1',
workspaceId: 'workspace-1',
featureKind: 'transcript',
quotaBackedRoutesAllowed: false,
})
);
Sinon.assert.notCalled(checkQuota);
});
test('ConversationHost should delegate empty no-message stream access', async t => {
const session = stubConversationSession();
const resolveTurnRouteAccess = Sinon.stub().rejects(
new Error('quota exceeded')
);
const host = new ConversationHost(
{
get: Sinon.stub().resolves(session),
revertLatestMessage: Sinon.stub().resolves(undefined),
} as any,
{} as any,
{} as any,
{ resolveTurnRouteAccess } as any
);
await t.throwsAsync(host.prepareTurn('user-1', 'session-1', {}), {
message: 'quota exceeded',
});
Sinon.assert.calledOnceWithMatch(resolveTurnRouteAccess, {
userId: 'user-1',
workspaceId: 'workspace-1',
});
});
test('ConversationHost should return access decision for empty no-message stream', async t => {
const session = stubConversationSession();
const resolveTurnRouteAccess = Sinon.stub().resolves({
byokProfiles: [{ id: 'profile-1' }],
quotaBackedRoutesAllowed: undefined,
});
const host = new ConversationHost(
{
get: Sinon.stub().resolves(session),
revertLatestMessage: Sinon.stub().resolves(undefined),
} as any,
{} as any,
{} as any,
{ resolveTurnRouteAccess } as any
);
const prepared = await host.prepareTurn('user-1', 'session-1', {});
t.is(prepared.latestTurn, undefined);
t.is(prepared.quotaBackedRoutesAllowed, undefined);
Sinon.assert.calledOnce(resolveTurnRouteAccess);
});
test('ConversationHost should replay accepted tokens without rechecking quota', async t => {
const acceptedTurn: Turn = {
id: 'turn-1',
conversationId: 'session-1',
role: 'user',
content: 'hello',
attachments: [],
metadata: {},
renderTrace: [],
toolEvents: [],
createdAt: new Date(),
};
const session = {
...stubConversationSession(acceptedTurn),
findTurn: Sinon.stub().withArgs('turn-1').returns(acceptedTurn),
};
const resolveTurnRouteAccess = Sinon.stub().rejects(
new Error('quota exceeded')
);
const host = new ConversationHost(
{
get: Sinon.stub().resolves(session),
revertLatestMessage: Sinon.stub().resolves(undefined),
} as any,
{
getAccepted: Sinon.stub().resolves({
sessionId: 'session-1',
turnId: 'turn-1',
}),
} as any,
{} as any,
{ resolveTurnRouteAccess } as any
);
const prepared = await host.prepareTurn('user-1', 'session-1', {
messageId: 'message-1',
});
t.is(prepared.latestTurn, acceptedTurn);
t.true(prepared.quotaBackedRoutesAllowed);
Sinon.assert.notCalled(resolveTurnRouteAccess);
});
test('ConversationHost should replay durable tokens without rechecking quota', async t => {
const durableTurn: Turn = {
id: 'turn-1',
conversationId: 'session-1',
role: 'user',
content: 'hello',
attachments: [],
metadata: {},
renderTrace: [],
toolEvents: [],
createdAt: new Date(),
};
const session = {
...stubConversationSession(durableTurn),
findTurn: Sinon.stub().withArgs('turn-1').returns(durableTurn),
pushPersistedTurn: Sinon.stub(),
};
const resolveTurnRouteAccess = Sinon.stub().rejects(
new Error('quota exceeded')
);
const markAccepted = Sinon.stub().resolves(undefined);
const host = new ConversationHost(
{
get: Sinon.stub().resolves(session),
findTurnByCompatSubmissionId: Sinon.stub().resolves(durableTurn),
revertLatestMessage: Sinon.stub().resolves(undefined),
} as any,
{
getAccepted: Sinon.stub().resolves(undefined),
markAccepted,
} as any,
{
acquire: Sinon.stub().resolves({
[Symbol.asyncDispose]: Sinon.stub().resolves(undefined),
}),
} as any,
{ resolveTurnRouteAccess } as any
);
const prepared = await host.prepareTurn('user-1', 'session-1', {
messageId: 'message-1',
});
t.is(prepared.latestTurn, durableTurn);
t.true(prepared.quotaBackedRoutesAllowed);
Sinon.assert.calledOnceWithMatch(markAccepted, 'message-1', {
sessionId: 'session-1',
turnId: 'turn-1',
});
Sinon.assert.notCalled(resolveTurnRouteAccess);
});
test('ToolRuntime should pass route context into prompt-backed tools', async t => {
const promptRuntime = {
runText: Sinon.stub().resolves('<html><body>done</body></html>'),
};
const runtime = new ToolRuntime(
{} as any,
{} as any,
{} as any,
{} as any,
{} as any,
{} as any,
promptRuntime as any,
{} as any
);
const tools = await runtime.getTools(
{
tools: ['codeArtifact'],
user: 'user-1',
session: 'session-1',
workspace: 'workspace-1',
byokLeaseId: 'lease-1',
featureKind: 'chat',
quotaBackedRoutesAllowed: false,
},
'gpt-4o-mini'
);
const result = await tools.code_artifact.execute?.(
{ title: 'Demo', userPrompt: 'build a page' },
{}
);
t.like(result as object, { title: 'Demo' });
Sinon.assert.calledOnceWithMatch(
promptRuntime.runText,
'Code Artifact',
{ content: 'build a page' },
{
providerOptions: {
user: 'user-1',
session: 'session-1',
workspace: 'workspace-1',
byokLeaseId: 'lease-1',
featureKind: 'chat',
quotaBackedRoutesAllowed: false,
},
}
);
});
test('ResponsePostprocessor should build text, object and image assistant turns', t => { test('ResponsePostprocessor should build text, object and image assistant turns', t => {
const postprocessor = new ResponsePostprocessor(); const postprocessor = new ResponsePostprocessor();
@@ -267,7 +634,7 @@ test('action result projection should map image result url to assistant attachme
t.deepEqual(turn?.attachments, ['https://example.com/final.png']); t.deepEqual(turn?.attachments, ['https://example.com/final.png']);
}); });
test('CopilotEmbeddingClientService should refresh configured client and clear unavailable client', async t => { test('CopilotEmbeddingClientService should keep dispatch client across global config refreshes', async t => {
const taskPolicy = { const taskPolicy = {
resolveEmbeddingModelId: () => 'text-embedding-3-large', resolveEmbeddingModelId: () => 'text-embedding-3-large',
}; };
@@ -288,8 +655,8 @@ test('CopilotEmbeddingClientService should refresh configured client and clear u
t.truthy(service.getClient()); t.truthy(service.getClient());
const second = await service.refresh(); const second = await service.refresh();
t.is(second, undefined); t.truthy(second);
t.is(service.getClient(), undefined); t.is(service.getClient(), second);
Sinon.assert.calledTwice(runtime.embeddingConfigured); Sinon.assert.calledTwice(runtime.embeddingConfigured);
Sinon.assert.alwaysCalledWithExactly( Sinon.assert.alwaysCalledWithExactly(
runtime.embeddingConfigured, runtime.embeddingConfigured,
@@ -297,6 +664,95 @@ test('CopilotEmbeddingClientService should refresh configured client and clear u
); );
}); });
test('CopilotEmbeddingClientService should keep workspace-routed embedding client without global provider', async t => {
const taskPolicy = {
resolveEmbeddingModelId: () => 'gemini-embedding-001',
resolveRerankModelId: () => 'gpt-4o-mini',
};
const runtime = {
embeddingConfigured: Sinon.stub().resolves(false),
};
const service = new CopilotEmbeddingClientService(
taskPolicy as any,
runtime as any
);
const client = await service.refresh();
t.truthy(client);
t.is(service.getClient(), client);
Sinon.assert.calledOnceWithExactly(
runtime.embeddingConfigured,
'gemini-embedding-001'
);
});
test('CopilotEmbeddingClientService should pass workspace context into embedding routes', async t => {
const signal = new AbortController().signal;
const taskPolicy = {
resolveEmbeddingModelId: () => 'gemini-embedding-001',
resolveRerankModelId: () => 'gpt-4o-mini',
};
const runtime = {
embeddingConfigured: Sinon.stub().resolves(true),
embed: Sinon.stub().resolves([[0.1]]),
rerank: Sinon.stub().resolves([0.8]),
};
const service = new CopilotEmbeddingClientService(
taskPolicy as any,
runtime as any
);
const client = await service.refresh();
t.truthy(client);
await client?.getEmbeddings(['hello'], {
workspaceId: 'workspace-1',
userId: 'user-1',
featureKind: 'workspace_indexing',
signal,
});
Sinon.assert.calledOnceWithMatch(
runtime.embed,
'gemini-embedding-001',
['hello'],
{
dimensions: Sinon.match.number,
workspace: 'workspace-1',
user: 'user-1',
featureKind: 'workspace_indexing',
signal,
}
);
await client?.reRank(
'hello',
[{ chunk: 0, content: 'hello', distance: 0.2 }],
1,
{
workspaceId: 'workspace-1',
userId: 'user-1',
featureKind: 'workspace_indexing',
signal,
}
);
Sinon.assert.calledOnceWithMatch(
runtime.rerank,
'gpt-4o-mini',
{
query: 'hello',
candidates: [{ id: '0', text: 'hello' }],
},
{
workspace: 'workspace-1',
user: 'user-1',
featureKind: 'rerank',
signal,
}
);
});
test('CompatHistoryProjector should compose visibility, prompt preload and attachment url projection', t => { test('CompatHistoryProjector should compose visibility, prompt preload and attachment url projection', t => {
const projector = new CompatHistoryProjector( const projector = new CompatHistoryProjector(
new HistoryVisibilityPolicy(), new HistoryVisibilityPolicy(),

View File

@@ -20,7 +20,6 @@ import {
import { GeminiProvider } from '../../plugins/copilot/providers/gemini/gemini'; import { GeminiProvider } from '../../plugins/copilot/providers/gemini/gemini';
import { GeminiVertexProvider } from '../../plugins/copilot/providers/gemini/vertex'; import { GeminiVertexProvider } from '../../plugins/copilot/providers/gemini/vertex';
import { OpenAIProvider } from '../../plugins/copilot/providers/openai'; import { OpenAIProvider } from '../../plugins/copilot/providers/openai';
import { PerplexityProvider } from '../../plugins/copilot/providers/perplexity';
import { import {
CopilotProviderType, CopilotProviderType,
type PromptMessage, type PromptMessage,
@@ -589,16 +588,6 @@ class TestOpenAIProvider extends OpenAIProvider {
} }
} }
class TestPerplexityProvider extends PerplexityProvider {
override get config() {
return { apiKey: 'perplexity-key' };
}
override configured() {
return true;
}
}
test('NativeProviderAdapter should append citation and attachment footnotes', async t => { test('NativeProviderAdapter should append citation and attachment footnotes', async t => {
const dispatch = () => const dispatch = () =>
(async function* (): AsyncIterableIterator<LlmToolLoopStreamEvent> { (async function* (): AsyncIterableIterator<LlmToolLoopStreamEvent> {
@@ -818,6 +807,91 @@ test('NativeProviderAdapter streamObject should map tool and text events', async
t.snapshot(events); t.snapshot(events);
}); });
test('NativeProviderAdapter streamObject should finalize usage with selected provider', async t => {
const usageEvents: Array<{
providerId: string;
model?: string;
usage?: {
prompt_tokens: number;
completion_tokens: number;
total_tokens: number;
cached_tokens?: number;
};
}> = [];
const adapter = new NativeProviderAdapter(
() =>
stream(() => [
{ type: 'message_start', model: 'gpt-5-mini' },
{ type: 'text_delta', text: 'ok' },
{
type: 'done',
finish_reason: 'stop',
usage: { prompt_tokens: 2, completion_tokens: 3, total_tokens: 5 },
},
{
type: 'provider_selected',
provider_id: 'byok-aaaaaaaaaaaa-openai-server-key1',
},
]),
{
onUsage: input => {
usageEvents.push(input);
},
}
);
const events = await collectChunks(
adapter.streamObject({
model: 'gpt-5-mini',
stream: true,
messages: nativeMessages(nativeUserText('hi')),
})
);
t.deepEqual(events, [{ type: 'text-delta', textDelta: 'ok' }]);
t.deepEqual(usageEvents, [
{
providerId: 'byok-aaaaaaaaaaaa-openai-server-key1',
model: 'gpt-5-mini',
usage: { prompt_tokens: 2, completion_tokens: 3, total_tokens: 5 },
},
]);
});
test('NativeProviderAdapter streamObject should keep streaming when usage callback fails', async t => {
const adapter = new NativeProviderAdapter(
() =>
stream(() => [
{ type: 'message_start', model: 'gpt-5-mini' },
{ type: 'text_delta', text: 'ok' },
{
type: 'done',
finish_reason: 'stop',
usage: { prompt_tokens: 2, completion_tokens: 3, total_tokens: 5 },
},
{
type: 'provider_selected',
provider_id: 'byok-aaaaaaaaaaaa-openai-server-key1',
},
]),
{
onUsage: () => {
throw new Error('usage callback failed');
},
}
);
const events = await collectChunks(
adapter.streamObject({
model: 'gpt-5-mini',
stream: true,
messages: nativeMessages(nativeUserText('hi')),
})
);
t.deepEqual(events, [{ type: 'text-delta', textDelta: 'ok' }]);
});
test('NativeRuntimeAdapter streamObject should keep raw runtime stream objects only', async t => { test('NativeRuntimeAdapter streamObject should keep raw runtime stream objects only', async t => {
const adapter = new NativeRuntimeAdapter( const adapter = new NativeRuntimeAdapter(
createTestToolLoopBridge(mockDispatch, {}, 3) createTestToolLoopBridge(mockDispatch, {}, 3)
@@ -1653,36 +1727,6 @@ test('GeminiProvider should not pass materialized inline attachment URL to nativ
t.false('url' in (attachmentPart?.source ?? {})); t.false('url' in (attachmentPart?.source ?? {}));
}); });
test('PerplexityProvider should ignore attachments during text model matching', async t => {
const provider = new TestPerplexityProvider();
let capturedRequest: LlmRequest | undefined;
(provider as any).getActiveProviderMiddleware = () => ({});
(provider as any).getTools = async () => ({});
(provider as any).createNativeAdapter = () => ({
text: async (request: LlmRequest) => {
capturedRequest = request;
return 'ok';
},
});
const result = await getProviderRuntimeHost(provider).run.text(
{ modelId: 'sonar' },
[
{
role: 'user',
content: 'summarize this',
attachments: ['https://example.com/a.pdf'],
params: { mimetype: 'application/pdf' },
},
],
{}
);
t.is(result, 'ok');
t.snapshot(capturedRequest?.messages[0]?.content);
});
test('GeminiProvider should reject unsupported attachment schemes at input validation', async t => { test('GeminiProvider should reject unsupported attachment schemes at input validation', async t => {
const provider = new TestGeminiProvider(); const provider = new TestGeminiProvider();

View File

@@ -3,7 +3,11 @@ import test from 'ava';
import Sinon from 'sinon'; import Sinon from 'sinon';
import { z } from 'zod'; import { z } from 'zod';
import { CopilotPromptInvalid, NoCopilotProviderAvailable } from '../../base'; import {
CopilotPromptInvalid,
CopilotQuotaExceeded,
NoCopilotProviderAvailable,
} from '../../base';
import { import {
type LlmBackendConfig, type LlmBackendConfig,
type LlmEmbeddingRequest, type LlmEmbeddingRequest,
@@ -20,9 +24,11 @@ import {
llmResolveRequestedModelMatch, llmResolveRequestedModelMatch,
type LlmStructuredRequest, type LlmStructuredRequest,
} from '../../native'; } from '../../native';
import type { ProviderMiddlewareConfig } from '../../plugins/copilot/config'; import type {
CopilotProviderProfile,
ProviderMiddlewareConfig,
} from '../../plugins/copilot/config';
import { CopilotProviderFactory } from '../../plugins/copilot/providers/factory'; import { CopilotProviderFactory } from '../../plugins/copilot/providers/factory';
import { MorphProvider } from '../../plugins/copilot/providers/morph';
import { OpenAIProvider } from '../../plugins/copilot/providers/openai'; import { OpenAIProvider } from '../../plugins/copilot/providers/openai';
import { CopilotProvider } from '../../plugins/copilot/providers/provider'; import { CopilotProvider } from '../../plugins/copilot/providers/provider';
import { buildProviderRegistry } from '../../plugins/copilot/providers/provider-registry'; import { buildProviderRegistry } from '../../plugins/copilot/providers/provider-registry';
@@ -62,6 +68,13 @@ import {
userPrompt, userPrompt,
} from './prompt-test-helper'; } from './prompt-test-helper';
function createNativeExecutionEngine() {
return new NativeExecutionEngine({
recordUsage: Sinon.stub().resolves(),
recordProviderFailure: Sinon.stub().resolves(),
} as never);
}
function structuredOptions( function structuredOptions(
schema: z.ZodTypeAny, schema: z.ZodTypeAny,
extra?: Record<string, unknown> extra?: Record<string, unknown>
@@ -541,7 +554,7 @@ test('CapabilityRuntime should defer no-route embedding plans to native engine',
}); });
test('NativeExecutionEngine should expose execute/executeStream as the single plan entrypoints', async t => { test('NativeExecutionEngine should expose execute/executeStream as the single plan entrypoints', async t => {
const engine = new NativeExecutionEngine(); const engine = createNativeExecutionEngine();
let dispatchCalls = 0; let dispatchCalls = 0;
let streamCalls = 0; let streamCalls = 0;
@@ -660,6 +673,279 @@ test('NativeExecutionEngine should expose execute/executeStream as the single pl
t.is(streamCalls, 1); t.is(streamCalls, 1);
}); });
test('NativeExecutionEngine should record BYOK usage when stream finalizes with selected provider', async t => {
const byok = {
recordUsage: Sinon.stub().resolves(),
};
const engine = new NativeExecutionEngine(byok as never);
const providerId = 'byok-aaaaaaaaaaaa-openai-server-key1';
const originalStream = (serverNativeModule as any).llmDispatchPreparedStream;
(serverNativeModule as any).llmDispatchPreparedStream = (
_routesJson: string,
callback: (error: Error | null, arg: string) => void
) => {
callback(
null,
JSON.stringify({
type: 'message_start',
model: 'gpt-5-mini',
})
);
callback(null, JSON.stringify({ type: 'text_delta', text: 'ok' }));
callback(
null,
JSON.stringify({
type: 'done',
finish_reason: 'stop',
usage: {
prompt_tokens: 2,
completion_tokens: 3,
total_tokens: 5,
},
})
);
callback(
null,
JSON.stringify({
type: 'provider_selected',
provider_id: providerId,
})
);
callback(null, '__AFFINE_LLM_STREAM_END__');
return { abort() {} };
};
t.teardown(() => {
(serverNativeModule as any).llmDispatchPreparedStream = originalStream;
});
const chunks = await collectAsync(
engine.executeStream({
nativeDispatch: {
chat: {
routes: [
nativeRoute({
providerId,
authToken: 'byok-key',
request: nativeTextRequest('hello'),
}),
],
prepared: {
route: preparedRoute({
providerId,
authToken: 'byok-key',
}),
request: nativeTextRequest('hello'),
tools: {},
postprocess: { nodeTextMiddleware: [] },
},
hasTools: false,
},
},
request: {
kind: 'streamText',
cond: { modelId: 'gpt-5-mini' },
messages: singleUserPromptMessages('hello'),
options: {
workspace: 'workspace-1',
user: 'user-1',
session: 'session-1',
featureKind: 'chat',
},
},
routePolicy: { fallbackOrder: [providerId] },
runtimePolicy: {},
attachmentPolicy: { materializeRemoteAttachments: true },
responsePostprocess: { mode: 'streamText' },
hostPersistence: { persistAssistantTurn: true, outputKind: 'streamText' },
hostContext: {},
})
);
t.deepEqual(chunks, ['ok']);
Sinon.assert.calledOnceWithMatch(byok.recordUsage, {
workspaceId: 'workspace-1',
userId: 'user-1',
sessionId: 'session-1',
featureKind: 'chat',
providerId,
model: 'gpt-5-mini',
usage: {
prompt_tokens: 2,
completion_tokens: 3,
total_tokens: 5,
},
});
});
test('NativeExecutionEngine should record plain text BYOK usage as chat by default', async t => {
const byok = {
recordUsage: Sinon.stub().resolves(),
recordProviderFailure: Sinon.stub().resolves(),
};
const engine = new NativeExecutionEngine(byok as never);
const providerId = 'byok-aaaaaaaaaaaa-openai-server-key1';
const originalDispatch = (serverNativeModule as any).llmDispatchPrepared;
(serverNativeModule as any).llmDispatchPrepared = () => {
return JSON.stringify({
provider_id: providerId,
response: {
id: 'chat_execute',
model: 'gpt-5-mini',
message: {
role: 'assistant',
content: [{ type: 'text', text: 'execute-ok' }],
},
usage: {
prompt_tokens: 1,
completion_tokens: 2,
total_tokens: 3,
},
finish_reason: 'stop',
},
});
};
t.teardown(() => {
(serverNativeModule as any).llmDispatchPrepared = originalDispatch;
});
const text = await engine.execute({
nativeDispatch: {
chat: {
routes: [
nativeRoute({
providerId,
authToken: 'byok-key',
request: nativeTextRequest('hello'),
}),
],
prepared: {
route: preparedRoute({
providerId,
authToken: 'byok-key',
}),
request: nativeTextRequest('hello'),
tools: {},
postprocess: { nodeTextMiddleware: [] },
},
hasTools: false,
},
},
request: {
kind: 'text',
cond: { modelId: 'gpt-5-mini' },
messages: singleUserPromptMessages('hello'),
options: {
workspace: 'workspace-1',
user: 'user-1',
session: 'session-1',
},
},
routePolicy: { fallbackOrder: [providerId] },
runtimePolicy: {},
attachmentPolicy: { materializeRemoteAttachments: true },
responsePostprocess: { mode: 'text' },
hostPersistence: { persistAssistantTurn: true, outputKind: 'text' },
hostContext: {},
});
t.is(text, 'execute-ok');
Sinon.assert.calledOnceWithMatch(byok.recordUsage, {
workspaceId: 'workspace-1',
userId: 'user-1',
sessionId: 'session-1',
featureKind: 'chat',
providerId,
model: 'gpt-5-mini',
usage: {
prompt_tokens: 1,
completion_tokens: 2,
total_tokens: 3,
},
});
});
test('NativeExecutionEngine should not fail stream when BYOK usage recording fails', async t => {
const byok = {
recordUsage: Sinon.stub().rejects(new Error('usage db down')),
};
const engine = new NativeExecutionEngine(byok as never);
const providerId = 'byok-aaaaaaaaaaaa-openai-server-key1';
const originalStream = (serverNativeModule as any).llmDispatchPreparedStream;
(serverNativeModule as any).llmDispatchPreparedStream = (
_routesJson: string,
callback: (error: Error | null, arg: string) => void
) => {
callback(
null,
JSON.stringify({ type: 'message_start', model: 'gpt-5-mini' })
);
callback(null, JSON.stringify({ type: 'text_delta', text: 'ok' }));
callback(
null,
JSON.stringify({
type: 'done',
finish_reason: 'stop',
usage: { prompt_tokens: 2, completion_tokens: 3, total_tokens: 5 },
})
);
callback(
null,
JSON.stringify({ type: 'provider_selected', provider_id: providerId })
);
callback(null, '__AFFINE_LLM_STREAM_END__');
return { abort() {} };
};
t.teardown(() => {
(serverNativeModule as any).llmDispatchPreparedStream = originalStream;
});
const chunks = await collectAsync(
engine.executeStream({
nativeDispatch: {
chat: {
routes: [
nativeRoute({
providerId,
authToken: 'byok-key',
request: nativeTextRequest('hello'),
}),
],
prepared: {
route: preparedRoute({ providerId, authToken: 'byok-key' }),
request: nativeTextRequest('hello'),
tools: {},
postprocess: { nodeTextMiddleware: [] },
},
hasTools: false,
},
},
request: {
kind: 'streamText',
cond: { modelId: 'gpt-5-mini' },
messages: singleUserPromptMessages('hello'),
options: {
workspace: 'workspace-1',
user: 'user-1',
session: 'session-1',
featureKind: 'chat',
},
},
routePolicy: { fallbackOrder: [providerId] },
runtimePolicy: {},
attachmentPolicy: { materializeRemoteAttachments: true },
responsePostprocess: { mode: 'streamText' },
hostPersistence: { persistAssistantTurn: true, outputKind: 'streamText' },
hostContext: {},
})
);
t.deepEqual(chunks, ['ok']);
Sinon.assert.calledOnce(byok.recordUsage);
});
test('CopilotProviderFactory should return no prepared routes when native prepare returns null', async t => { test('CopilotProviderFactory should return no prepared routes when native prepare returns null', async t => {
const provider = new DriverOnlyProvider(); const provider = new DriverOnlyProvider();
(provider as any).AFFiNEConfig = { copilot: { providers: { openai: {} } } }; (provider as any).AFFiNEConfig = { copilot: { providers: { openai: {} } } };
@@ -693,9 +979,16 @@ test('CopilotProviderFactory should return no prepared routes when native prepar
enableFeature: Sinon.stub(), enableFeature: Sinon.stub(),
disableFeature: Sinon.stub(), disableFeature: Sinon.stub(),
}; };
const access = {
resolveRouteAccess: Sinon.stub().resolves({
byokProfiles: [],
quotaBackedRoutesAvailable: true,
}),
};
const factory = new CopilotProviderFactory( const factory = new CopilotProviderFactory(
server as never, server as never,
registryService as never registryService as never,
access as never
); );
factory.register('openai-main', provider); factory.register('openai-main', provider);
@@ -923,52 +1216,6 @@ test('driver-only provider should require explicit structured response contracts
t.is(capturedRequest, undefined); t.is(capturedRequest, undefined);
}); });
test('MorphProvider should reuse the base native chat driver template', async t => {
const provider = new MorphProvider();
(provider as any).AFFiNEConfig = {
copilot: { providers: { morph: { apiKey: 'test-key' } } },
};
(provider as any).toolExecutorHost = {
createNativeAdapter: () => ({
text: async () => 'morph text',
streamText: async function* () {
yield 'morph stream';
},
streamObject: async function* () {
yield { type: 'text-delta', textDelta: 'unused' };
},
}),
getTools: async () => ({}),
};
t.is(
await getProviderRuntimeHost(provider).run.text(
{ modelId: 'morph-v3-fast' },
promptMessages(userPrompt('hello'))
),
'morph text'
);
t.deepEqual(
await collectAsync(
getProviderRuntimeHost(provider).run.streamText(
{ modelId: 'morph-v3-fast' },
promptMessages(userPrompt('hello'))
)
),
['morph stream']
);
t.is(
await getProviderRuntimeHost(provider).prepare.chat(
'streamObject',
{
modelId: 'morph-v3-fast',
},
promptMessages(userPrompt('hello'))
),
null
);
});
test('getActiveProviderMiddleware should merge defaults with profile override', t => { test('getActiveProviderMiddleware should merge defaults with profile override', t => {
const provider = createProvider({ const provider = createProvider({
rust: { request: ['clamp_max_tokens'] }, rust: { request: ['clamp_max_tokens'] },
@@ -1231,9 +1478,16 @@ test('CopilotProviderFactory should resolve legacy model ids through native regi
enableFeature: Sinon.stub(), enableFeature: Sinon.stub(),
disableFeature: Sinon.stub(), disableFeature: Sinon.stub(),
}; };
const access = {
resolveRouteAccess: Sinon.stub().resolves({
byokProfiles: [],
quotaBackedRoutesAvailable: true,
}),
};
const factory = new CopilotProviderFactory( const factory = new CopilotProviderFactory(
server as never, server as never,
registryService as never registryService as never,
access as never
); );
factory.register('openai-main', provider); factory.register('openai-main', provider);
@@ -1242,6 +1496,230 @@ test('CopilotProviderFactory should resolve legacy model ids through native regi
t.is(provider.resolveModel('gpt-5-2025-08-07')?.id, 'gpt-5'); t.is(provider.resolveModel('gpt-5-2025-08-07')?.id, 'gpt-5');
}); });
const BYOK_OPENAI_PROFILE: CopilotProviderProfile = {
id: 'byok-aaaaaaaaaaaa-openai-server-key1',
type: CopilotProviderType.OpenAI,
priority: 10_000,
config: { apiKey: 'byok-key' },
};
const BYOK_FAL_PROFILE: CopilotProviderProfile = {
id: 'byok-aaaaaaaaaaaa-fal-server-key1',
type: CopilotProviderType.FAL,
priority: 10_000,
config: { apiKey: 'byok-key' },
};
function createProviderFactoryWithByokRoutes({
byokProfiles = [BYOK_OPENAI_PROFILE],
hasQuota = true,
}: {
byokProfiles?: CopilotProviderProfile[];
hasQuota?: boolean;
} = {}) {
const provider = createProvider();
const registryService = {
getRegistry: () =>
buildProviderRegistry({
profiles: [
{
id: 'openai-main',
type: CopilotProviderType.OpenAI,
priority: 1,
config: { apiKey: 'test-key' },
},
],
defaults: {},
}),
};
const server = {
enableFeature: Sinon.stub(),
disableFeature: Sinon.stub(),
};
const byok = {
getProfiles: Sinon.stub().resolves(byokProfiles),
};
const access = {
resolveRouteAccess: Sinon.stub().callsFake(async context => ({
byokProfiles: await byok.getProfiles(context),
quotaBackedRoutesAvailable: context.quotaBackedRoutesAllowed ?? hasQuota,
})),
};
const factory = new CopilotProviderFactory(
server as never,
registryService as never,
access as never
);
factory.register('openai-main', provider);
return { factory, byok };
}
test('CopilotProviderFactory should use matching BYOK routes before quota-backed routes', async t => {
const { factory } = createProviderFactoryWithByokRoutes();
const routes = await factory.resolveRoutes(
{ modelId: 'gpt-5-mini', outputType: ModelOutputType.Text },
{},
{ userId: 'user-1', workspaceId: 'workspace-1' }
);
t.deepEqual(
routes.map(route => route.providerId),
['byok-aaaaaaaaaaaa-openai-server-key1']
);
});
test('CopilotProviderFactory should skip unsupported BYOK profiles and use quota-backed fallback', async t => {
const { factory } = createProviderFactoryWithByokRoutes({
byokProfiles: [BYOK_FAL_PROFILE],
});
const routes = await factory.resolveRoutes(
{ modelId: 'gpt-5-mini', outputType: ModelOutputType.Text },
{},
{ userId: 'user-1', workspaceId: 'workspace-1' }
);
t.deepEqual(
routes.map(route => route.providerId),
['openai-main']
);
});
test('CopilotProviderFactory should resolve BYOK embedding routes with workspace context', async t => {
const { factory, byok } = createProviderFactoryWithByokRoutes();
const routes = await factory.resolveRoutes(
{
modelId: 'text-embedding-3-small',
outputType: ModelOutputType.Embedding,
},
{},
{ workspaceId: 'workspace-1', featureKind: 'workspace_indexing' }
);
t.deepEqual(
routes.map(route => route.providerId),
['byok-aaaaaaaaaaaa-openai-server-key1']
);
Sinon.assert.calledOnceWithMatch(byok.getProfiles, {
workspaceId: 'workspace-1',
featureKind: 'workspace_indexing',
});
});
test('CopilotProviderFactory should treat embedding preparation as embedding feature by default', async t => {
const { factory, byok } = createProviderFactoryWithByokRoutes();
await factory.prepareEmbeddingRoutes('text-embedding-3-small', 'hello', {
workspace: 'workspace-1',
});
t.true(byok.getProfiles.calledOnce);
Sinon.assert.calledOnceWithMatch(byok.getProfiles, {
workspaceId: 'workspace-1',
featureKind: 'embedding',
});
});
test('CopilotProviderFactory should resolve BYOK rerank routes before quota-backed routes', async t => {
const { factory, byok } = createProviderFactoryWithByokRoutes();
const preparedRoutes = await factory.prepareRerankRoutes(
'gpt-4o-mini',
{
query: 'programming',
candidates: [{ text: 'React is a UI library.' }],
},
{ workspace: 'workspace-1' }
);
const resolvedRoutes = await factory.resolveRoutes(
{ modelId: 'gpt-4o-mini', outputType: ModelOutputType.Rerank },
{},
{ workspaceId: 'workspace-1', featureKind: 'rerank' }
);
t.deepEqual(
preparedRoutes.map(route => route.providerId),
[]
);
t.deepEqual(
resolvedRoutes.map(route => route.providerId),
['byok-aaaaaaaaaaaa-openai-server-key1']
);
Sinon.assert.calledWithMatch(byok.getProfiles, {
workspaceId: 'workspace-1',
featureKind: 'rerank',
});
});
test('CopilotProviderFactory should treat image preparation as image feature by default', async t => {
const { factory, byok } = createProviderFactoryWithByokRoutes();
await factory.prepareImageRoutes(
{ modelId: 'gpt-image-1', outputType: ModelOutputType.Image },
singleUserPromptMessages('draw a cat'),
{ workspace: 'workspace-1' }
);
t.true(byok.getProfiles.calledOnce);
Sinon.assert.calledOnceWithMatch(byok.getProfiles, {
workspaceId: 'workspace-1',
featureKind: 'image',
});
});
test('CopilotProviderFactory should omit quota-backed routes when quota is exhausted', async t => {
const { factory } = createProviderFactoryWithByokRoutes({ hasQuota: false });
const routes = await factory.resolveRoutes(
{ modelId: 'gpt-5-mini', outputType: ModelOutputType.Text },
{},
{ userId: 'user-1', workspaceId: 'workspace-1' }
);
t.deepEqual(
routes.map(route => route.providerId),
['byok-aaaaaaaaaaaa-openai-server-key1']
);
});
test('CopilotProviderFactory should raise quota exceeded when only quota-backed routes match', async t => {
const { factory } = createProviderFactoryWithByokRoutes({
byokProfiles: [],
hasQuota: false,
});
await t.throwsAsync(
factory.resolveRoutes(
{ modelId: 'gpt-5-mini', outputType: ModelOutputType.Text },
{},
{ userId: 'user-1', workspaceId: 'workspace-1' }
),
{ instanceOf: CopilotQuotaExceeded }
);
});
test('CopilotProviderFactory should not report quota exhausted when quota-backed routes are disabled', async t => {
const { factory } = createProviderFactoryWithByokRoutes({
byokProfiles: [],
hasQuota: true,
});
const routes = await factory.resolveRoutes(
{ modelId: 'gpt-5-mini', outputType: ModelOutputType.Text },
{},
{
userId: 'user-1',
workspaceId: 'workspace-1',
quotaBackedRoutesAllowed: false,
}
);
t.deepEqual(routes, []);
});
test('selectModel should reject unknown models without online fallback', t => { test('selectModel should reject unknown models without online fallback', t => {
const provider = new TestOpenAIProvider(); const provider = new TestOpenAIProvider();
t.is(provider.resolveModel('online-preview'), undefined); t.is(provider.resolveModel('online-preview'), undefined);
@@ -1476,7 +1954,7 @@ test('ProviderDriverSpec should freeze declarative driver shape', t => {
}); });
test('NativeExecutionEngine should dispatch prepared text routes through native fallback', async t => { test('NativeExecutionEngine should dispatch prepared text routes through native fallback', async t => {
const engine = new NativeExecutionEngine(); const engine = createNativeExecutionEngine();
const registry = buildProviderRegistry({ const registry = buildProviderRegistry({
profiles: [ profiles: [
{ {
@@ -1575,8 +2053,78 @@ test('NativeExecutionEngine should dispatch prepared text routes through native
t.snapshot(summarizePreparedDispatchRoutes(capturedRoutes)); t.snapshot(summarizePreparedDispatchRoutes(capturedRoutes));
}); });
test('NativeExecutionEngine should record single BYOK route dispatch failure', async t => {
const byok = {
recordProviderFailure: Sinon.stub().resolves(),
recordUsage: Sinon.stub().resolves(),
};
const engine = new NativeExecutionEngine(byok as never);
const providerId = 'byok-aaaaaaaaaaaa-openai-server-key1';
const original = (serverNativeModule as any).llmDispatchPrepared;
(serverNativeModule as any).llmDispatchPrepared = () => {
throw new Error('401 invalid sk-test-primary');
};
t.teardown(() => {
(serverNativeModule as any).llmDispatchPrepared = original;
});
const error = await t.throwsAsync(
engine.execute({
nativeDispatch: {
chat: {
routes: [
nativeRoute({
providerId,
authToken: 'primary-key',
request: nativeTextRequest('hello'),
}),
],
prepared: {
route: preparedRoute({
providerId,
authToken: 'primary-key',
}),
request: nativeTextRequest('hello'),
tools: {},
postprocess: { nodeTextMiddleware: [] },
},
hasTools: false,
},
},
request: {
kind: 'text',
cond: { modelId: 'gpt-5-mini' },
messages: singleUserPromptMessages('hello'),
options: {
workspace: 'workspace-1',
user: 'user-1',
session: 'session-1',
featureKind: 'chat',
},
},
routePolicy: { fallbackOrder: [providerId] },
runtimePolicy: {},
attachmentPolicy: { materializeRemoteAttachments: true },
responsePostprocess: { mode: 'text' },
hostPersistence: { persistAssistantTurn: true, outputKind: 'text' },
hostContext: {
currentMessages: singleUserPromptMessages('hello'),
},
})
);
t.truthy(error);
Sinon.assert.calledOnceWithMatch(byok.recordProviderFailure, {
workspaceId: 'workspace-1',
providerId,
featureKind: 'chat',
});
Sinon.assert.notCalled(byok.recordUsage);
});
test('NativeExecutionEngine should reject single-route plans when no native route is prepared', async t => { test('NativeExecutionEngine should reject single-route plans when no native route is prepared', async t => {
const engine = new NativeExecutionEngine(); const engine = createNativeExecutionEngine();
const error = await t.throwsAsync( const error = await t.throwsAsync(
engine.execute({ engine.execute({
@@ -1604,7 +2152,7 @@ test('NativeExecutionEngine should reject single-route plans when no native rout
}); });
test('NativeExecutionEngine should prefer prepared native fallback dispatch for explicit routes', async t => { test('NativeExecutionEngine should prefer prepared native fallback dispatch for explicit routes', async t => {
const engine = new NativeExecutionEngine(); const engine = createNativeExecutionEngine();
let capturedRoutes: unknown; let capturedRoutes: unknown;
let called = false; let called = false;
@@ -1683,7 +2231,7 @@ test('NativeExecutionEngine should prefer prepared native fallback dispatch for
}); });
test('NativeExecutionEngine should stream through prepared native fallback dispatch', async t => { test('NativeExecutionEngine should stream through prepared native fallback dispatch', async t => {
const engine = new NativeExecutionEngine(); const engine = createNativeExecutionEngine();
let called = false; let called = false;
const original = (serverNativeModule as any).llmDispatchPreparedStream; const original = (serverNativeModule as any).llmDispatchPreparedStream;
@@ -1912,7 +2460,7 @@ test('ExecutionPlanBuilder should keep single-route tool chat plans on prepared_
}); });
test('NativeExecutionEngine should route tool-loop chat prepared routes through native dispatch', async t => { test('NativeExecutionEngine should route tool-loop chat prepared routes through native dispatch', async t => {
const engine = new NativeExecutionEngine(); const engine = createNativeExecutionEngine();
let capturedRoutes: unknown; let capturedRoutes: unknown;
let called = false; let called = false;
let toolCallbackCount = 0; let toolCallbackCount = 0;
@@ -2262,7 +2810,7 @@ test('ExecutionPlanBuilder should build native prepared routes for structured, i
}); });
test('NativeExecutionEngine should dispatch structured prepared routes through native execution', async t => { test('NativeExecutionEngine should dispatch structured prepared routes through native execution', async t => {
const engine = new NativeExecutionEngine(); const engine = createNativeExecutionEngine();
let capturedRoutes: unknown; let capturedRoutes: unknown;
let called = false; let called = false;
@@ -2350,7 +2898,7 @@ test('NativeExecutionEngine should dispatch structured prepared routes through n
}); });
test('NativeExecutionEngine should dispatch embedding prepared routes through native execution', async t => { test('NativeExecutionEngine should dispatch embedding prepared routes through native execution', async t => {
const engine = new NativeExecutionEngine(); const engine = createNativeExecutionEngine();
let capturedRoutes: unknown; let capturedRoutes: unknown;
let called = false; let called = false;
@@ -2424,7 +2972,7 @@ test('NativeExecutionEngine should dispatch embedding prepared routes through na
}); });
test('NativeExecutionEngine should dispatch rerank prepared routes through native execution', async t => { test('NativeExecutionEngine should dispatch rerank prepared routes through native execution', async t => {
const engine = new NativeExecutionEngine(); const engine = createNativeExecutionEngine();
let capturedRoutes: unknown; let capturedRoutes: unknown;
let called = false; let called = false;
@@ -2507,7 +3055,7 @@ test('NativeExecutionEngine should dispatch rerank prepared routes through nativ
}); });
test('NativeExecutionEngine should dispatch image plans through prepared native routes', async t => { test('NativeExecutionEngine should dispatch image plans through prepared native routes', async t => {
const engine = new NativeExecutionEngine(); const engine = createNativeExecutionEngine();
let capturedRoutes: unknown; let capturedRoutes: unknown;
const original = (serverNativeModule as any).llmImageDispatchPrepared; const original = (serverNativeModule as any).llmImageDispatchPrepared;
(serverNativeModule as any).llmImageDispatchPrepared = ( (serverNativeModule as any).llmImageDispatchPrepared = (
@@ -2587,8 +3135,90 @@ test('NativeExecutionEngine should dispatch image plans through prepared native
t.snapshot(summarizePreparedDispatchRoutes(capturedRoutes)); t.snapshot(summarizePreparedDispatchRoutes(capturedRoutes));
}); });
test('NativeExecutionEngine should record zero-token BYOK image usage without provider usage', async t => {
const byok = {
recordUsage: Sinon.stub().resolves(),
};
const engine = new NativeExecutionEngine(byok as never);
const providerId = 'byok-aaaaaaaaaaaa-fal-server-key1';
const original = (serverNativeModule as any).llmImageDispatchPrepared;
(serverNativeModule as any).llmImageDispatchPrepared = () => {
return JSON.stringify({
provider_id: providerId,
response: {
images: [
{
url: 'https://cdn.example.com/image.png',
media_type: 'image/png',
},
],
},
});
};
t.teardown(() => {
(serverNativeModule as any).llmImageDispatchPrepared = original;
});
const request = nativeImageRequest('draw a cat');
const imageArtifacts = await collectAsync(
engine.executeImageArtifacts({
nativeDispatch: {
image: {
routes: [
nativeRoute({
providerId,
authToken: 'image-key',
protocol: 'fal_image',
model: 'fal-ai/fast-sdxl',
request,
}),
],
prepared: {
route: preparedRoute({
providerId,
authToken: 'image-key',
protocol: 'fal_image',
model: 'fal-ai/fast-sdxl',
}),
request,
},
},
},
request: {
kind: 'image',
cond: { modelId: 'fal-ai/fast-sdxl' },
messages: singleUserPromptMessages('draw a cat'),
options: {
workspace: 'workspace-1',
user: 'user-1',
session: 'session-1',
featureKind: 'image',
},
},
routePolicy: { fallbackOrder: [providerId] },
runtimePolicy: {},
attachmentPolicy: { materializeRemoteAttachments: true },
responsePostprocess: { mode: 'image' },
hostPersistence: { persistAssistantTurn: true, outputKind: 'image' },
hostContext: {},
})
);
t.is(imageArtifacts.length, 1);
Sinon.assert.calledOnceWithMatch(byok.recordUsage, {
workspaceId: 'workspace-1',
userId: 'user-1',
sessionId: 'session-1',
featureKind: 'image',
providerId,
model: 'fal-ai/fast-sdxl',
usage: undefined,
});
});
test('NativeExecutionEngine should reject image plans without native dispatch', async t => { test('NativeExecutionEngine should reject image plans without native dispatch', async t => {
const engine = new NativeExecutionEngine(); const engine = createNativeExecutionEngine();
await t.throwsAsync( await t.throwsAsync(
collectAsync( collectAsync(

View File

@@ -534,6 +534,58 @@ test('doc_semantic_search should return empty array when nothing matches', async
t.deepEqual(result, []); t.deepEqual(result, []);
}); });
test('doc_semantic_search should pass BYOK route context into embedding matches', async t => {
const ac = {
user: () => ({
workspace: () => ({
can: async () => true,
docs: async () => [],
}),
}),
} as unknown as AccessController;
const models = {
workspace: {
get: async () => ({ id: 'workspace-1' }),
},
} as unknown as Models;
let workspaceRouteContext: unknown;
let sessionRouteContext: unknown;
const contextService = {
matchWorkspaceAll: async (...args: unknown[]) => {
workspaceRouteContext = args[7];
return [];
},
getBySessionId: async () => ({
matchFiles: async (...args: unknown[]) => {
sessionRouteContext = args[5];
return [];
},
}),
} as unknown as Parameters<typeof buildDocSearchGetter>[1];
const semanticTool = createDocSemanticSearchTool(
buildDocSearchGetter(ac, contextService, 'session-1', models).bind(null, {
user: 'user-1',
workspace: 'workspace-1',
byokLeaseId: 'lease-1',
})
);
const result = await semanticTool.execute?.({ query: 'hello' }, {});
t.deepEqual(result, []);
t.deepEqual(workspaceRouteContext, {
userId: 'user-1',
byokLeaseId: 'lease-1',
});
t.deepEqual(sessionRouteContext, {
userId: 'user-1',
byokLeaseId: 'lease-1',
});
});
test('blob_read should return explicit error when attachment context is missing', async t => { test('blob_read should return explicit error when attachment context is missing', async t => {
const ac = { const ac = {
user: () => ({ user: () => ({

View File

@@ -215,7 +215,7 @@ test('settleTask checks copilot quota before unlocking ready task', async t => {
status: 'settled', status: 'settled',
protectedResult: payload, protectedResult: payload,
}); });
const checkQuota = Sinon.stub().rejects(new Error('quota exceeded')); const assertQuotaOrByok = Sinon.stub().rejects(new Error('quota exceeded'));
const service = new CopilotTranscriptionService( const service = new CopilotTranscriptionService(
{ {
copilotTranscriptTask: { copilotTranscriptTask: {
@@ -232,14 +232,18 @@ test('settleTask checks copilot quota before unlocking ready task', async t => {
{} as never, {} as never,
{} as never, {} as never,
{} as never, {} as never,
{ checkQuota } as never { assertQuotaOrByok } as never
); );
await t.throwsAsync( await t.throwsAsync(
() => service.settleTask('user-1', 'workspace-1', 'task-1'), () => service.settleTask('user-1', 'workspace-1', 'task-1'),
{ message: /quota exceeded/ } { message: /quota exceeded/ }
); );
Sinon.assert.calledOnceWithExactly(checkQuota, 'user-1'); Sinon.assert.calledOnceWithMatch(assertQuotaOrByok, {
userId: 'user-1',
workspaceId: 'workspace-1',
featureKind: 'transcript',
});
Sinon.assert.notCalled(settle); Sinon.assert.notCalled(settle);
}); });
@@ -341,6 +345,48 @@ test('retryTask reuses failed task and queues a new action attempt', async t =>
Sinon.assert.calledOnceWithExactly(markRunning, 'task-1'); Sinon.assert.calledOnceWithExactly(markRunning, 'task-1');
}); });
test('retryTask prechecks quota or BYOK before queueing provider work', async t => {
const add = Sinon.stub().resolves(undefined);
const markRunning = Sinon.stub().resolves({ id: 'task-1' });
const assertQuotaOrByok = Sinon.stub().rejects(new Error('quota exceeded'));
const payload = TranscriptPayloadSchema.parse({
normalizedTranscript: '00:00:05 A: Kickoff',
});
const service = new CopilotTranscriptionService(
{
copilotTranscriptTask: {
getWithUser: Sinon.stub().resolves({
id: 'task-1',
status: 'failed',
strategy: 'gemini',
protectedResult: payload,
}),
markRunning,
},
} as never,
{ add } as never,
{} as never,
{
resolveTranscriptionModel: Sinon.stub().resolves('gemini-2.5-flash'),
} as never,
{} as never,
{} as never,
{ assertQuotaOrByok } as never
);
await t.throwsAsync(
() => service.retryTask('user-1', 'workspace-1', 'task-1'),
{ message: /quota exceeded/ }
);
Sinon.assert.calledOnceWithMatch(assertQuotaOrByok, {
userId: 'user-1',
workspaceId: 'workspace-1',
featureKind: 'transcript',
});
Sinon.assert.notCalled(add);
Sinon.assert.notCalled(markRunning);
});
for (const status of ['ready', 'settled']) { for (const status of ['ready', 'settled']) {
test(`submitTask allows a new task for the same blob after ${status} task`, async t => { test(`submitTask allows a new task for the same blob after ${status} task`, async t => {
const createdTasks: unknown[] = []; const createdTasks: unknown[] = [];
@@ -390,6 +436,37 @@ for (const status of ['ready', 'settled']) {
}); });
} }
test('submitTask prechecks quota or BYOK before persisting uploads', async t => {
const assertQuotaOrByok = Sinon.stub().rejects(new Error('quota exceeded'));
const resolveTranscriptionModel = Sinon.stub().resolves('gemini-2.5-flash');
const service = new CopilotTranscriptionService(
{
copilotTranscriptTask: {
getWithUser: Sinon.stub().resolves(null),
},
} as never,
{} as never,
{} as never,
{
resolveTranscriptionModel,
} as never,
{} as never,
{} as never,
{ assertQuotaOrByok } as never
);
await t.throwsAsync(
() => service.submitTask('user-1', 'workspace-1', 'blob-1', []),
{ message: /quota exceeded/ }
);
Sinon.assert.calledOnceWithMatch(assertQuotaOrByok, {
userId: 'user-1',
workspaceId: 'workspace-1',
featureKind: 'transcript',
});
Sinon.assert.notCalled(resolveTranscriptionModel);
});
test('submitTask rejects unavailable transcript strategy', async t => { test('submitTask rejects unavailable transcript strategy', async t => {
const service = new CopilotTranscriptionService( const service = new CopilotTranscriptionService(
{ {

View File

@@ -1145,6 +1145,110 @@ test('should count action runs without double-counting legacy action sessions',
t.truthy(legacyAction.sessionId); t.truthy(legacyAction.sessionId);
}); });
test('should exclude BYOK provider usage from copilot quota cost', async t => {
const { copilotSession, db, models } = t.context;
await createTestPrompts(copilotSession, db);
const regular = await createTestSession(t);
const firstMessage = await copilotSession.appendMessage({
sessionId: regular.sessionId,
userId: user.id,
prompt: { model: 'gpt-5-mini' },
message: {
role: 'user',
content: 'regular message',
createdAt: new Date(),
},
});
const secondMessage = await copilotSession.appendMessage({
sessionId: regular.sessionId,
userId: user.id,
prompt: { model: 'gpt-5-mini' },
message: {
role: 'user',
content: 'second BYOK message',
createdAt: new Date(),
},
});
await copilotSession.appendMessage({
sessionId: regular.sessionId,
userId: user.id,
prompt: { model: 'gpt-5-mini' },
message: {
role: 'user',
content: 'quota-backed message',
createdAt: new Date(),
},
});
const failedRun = await models.copilotActionRun.create({
userId: user.id,
workspaceId: workspace.id,
actionId: 'mindmap.generate',
actionVersion: 'v1',
});
await models.copilotActionRun.complete(failedRun.id, {
status: 'failed',
errorCode: 'test_failed',
});
const pendingTranscriptTask = await models.copilotTranscriptTask.create({
userId: user.id,
workspaceId: workspace.id,
blobId: 'pending-audio',
strategy: 'gemini',
recipeId: 'transcript.audio.gemini',
recipeVersion: 'v1',
});
await models.copilotUsage.create({
workspaceId: workspace.id,
userId: user.id,
provider: 'openai',
providerSource: 'byok_server',
featureKind: 'chat',
billingUnitId: firstMessage.id,
});
await models.copilotUsage.create({
workspaceId: workspace.id,
userId: user.id,
provider: 'openai',
providerSource: 'byok_server',
featureKind: 'chat',
billingUnitId: firstMessage.id,
});
await models.copilotUsage.create({
workspaceId: workspace.id,
userId: user.id,
provider: 'fal',
providerSource: 'byok_server',
featureKind: 'image',
billingUnitId: secondMessage.id,
});
await models.copilotUsage.create({
workspaceId: workspace.id,
userId: user.id,
provider: 'fal',
providerSource: 'byok_server',
featureKind: 'image',
});
await models.copilotUsage.create({
workspaceId: workspace.id,
userId: user.id,
provider: 'openai',
providerSource: 'byok_server',
featureKind: 'action',
billingUnitId: failedRun.id,
});
await models.copilotUsage.create({
workspaceId: workspace.id,
userId: user.id,
provider: 'gemini',
providerSource: 'byok_server',
featureKind: 'transcript',
billingUnitId: pendingTranscriptTask.id,
});
t.is(await copilotSession.countUserMessages(user.id), 1);
});
test('should get sessions for title generation correctly', async t => { test('should get sessions for title generation correctly', async t => {
const { copilotSession, db } = t.context; const { copilotSession, db } = t.context;
await createTestPrompts(copilotSession, db); await createTestPrompts(copilotSession, db);

View File

@@ -5,8 +5,9 @@ import ava, { ExecutionContext, TestFn } from 'ava';
import Sinon from 'sinon'; import Sinon from 'sinon';
import { Doc as YDoc } from 'yjs'; import { Doc as YDoc } from 'yjs';
import { MockEventBus } from '../../../__tests__/mocks';
import { createTestingApp, type TestingApp } from '../../../__tests__/utils'; import { createTestingApp, type TestingApp } from '../../../__tests__/utils';
import { ConfigFactory } from '../../../base'; import { ConfigFactory, EventBus } from '../../../base';
import { Flavor } from '../../../env'; import { Flavor } from '../../../env';
import { Models } from '../../../models'; import { Models } from '../../../models';
import { DocReader, PgWorkspaceDocStorageAdapter } from '../../doc'; import { DocReader, PgWorkspaceDocStorageAdapter } from '../../doc';
@@ -16,6 +17,7 @@ interface Context {
app: TestingApp; app: TestingApp;
adapter: PgWorkspaceDocStorageAdapter; adapter: PgWorkspaceDocStorageAdapter;
docReader: DocReader; docReader: DocReader;
recordDocView: Sinon.SinonStub;
} }
const test = ava as TestFn<Context>; const test = ava as TestFn<Context>;
@@ -23,7 +25,9 @@ const test = ava as TestFn<Context>;
test.before(async t => { test.before(async t => {
// @ts-expect-error testing // @ts-expect-error testing
env.FLAVOR = Flavor.Renderer; env.FLAVOR = Flavor.Renderer;
const app = await createTestingApp(); const app = await createTestingApp({
tapModule: m => m.overrideProvider(EventBus).useClass(MockEventBus),
});
t.context.models = app.get(Models); t.context.models = app.get(Models);
t.context.adapter = app.get(PgWorkspaceDocStorageAdapter); t.context.adapter = app.get(PgWorkspaceDocStorageAdapter);
@@ -45,6 +49,14 @@ test.beforeEach(async t => {
email: 'test@affine.pro', email: 'test@affine.pro',
}); });
workspace = await t.context.models.workspace.create(user.id); workspace = await t.context.models.workspace.create(user.id);
t.context.recordDocView = Sinon.stub(
t.context.models.workspaceAnalytics,
'recordDocView'
).resolves();
});
test.afterEach.always(t => {
t.context.recordDocView?.restore();
}); });
test.after.always(async t => { test.after.always(async t => {
@@ -88,10 +100,7 @@ test('should record page view when rendering shared page', async t => {
title: 'analytics-doc', title: 'analytics-doc',
summary: 'summary', summary: 'summary',
}); });
const record = Sinon.stub( const record = t.context.recordDocView;
models.workspaceAnalytics,
'recordDocView'
).resolves();
await app.GET(`/workspace/${workspace.id}/${docId}`).expect(200); await app.GET(`/workspace/${workspace.id}/${docId}`).expect(200);
@@ -103,7 +112,6 @@ test('should record page view when rendering shared page', async t => {
}); });
docContent.restore(); docContent.restore();
record.restore();
}); });
const policyCases: Array<{ const policyCases: Array<{
@@ -146,10 +154,7 @@ const policyCases: Array<{
unknownBlocks: [], unknownBlocks: [],
}), }),
docContent: Sinon.stub(docReader, 'getDocContent'), docContent: Sinon.stub(docReader, 'getDocContent'),
record: Sinon.stub( record: models.workspaceAnalytics.recordDocView as Sinon.SinonStub,
models.workspaceAnalytics,
'recordDocView'
).resolves(),
}; };
}, },
request: (app, docId) => request: (app, docId) =>

View File

@@ -0,0 +1,139 @@
import { Injectable } from '@nestjs/common';
import { Transactional } from '@nestjs-cls/transactional';
import { BaseModel } from './base';
export type UpsertAiWorkspaceByokConfigInput = {
id?: string | null;
workspaceId: string;
provider: string;
name: string;
description: string | null;
encryptedApiKey?: string;
endpoint: string | null;
sortOrder: number;
enabled: boolean;
userId?: string;
};
@Injectable()
export class CopilotWorkspaceByokConfigModel extends BaseModel {
async list(workspaceId: string) {
return await this.db.aiWorkspaceByokConfig.findMany({
where: { workspaceId },
orderBy: [{ sortOrder: 'asc' }, { createdAt: 'asc' }],
});
}
async listEnabled(workspaceId: string) {
return await this.db.aiWorkspaceByokConfig.findMany({
where: { workspaceId, enabled: true },
orderBy: [{ sortOrder: 'asc' }, { createdAt: 'asc' }],
});
}
async get(id: string) {
return await this.db.aiWorkspaceByokConfig.findUnique({
where: { id },
});
}
@Transactional()
async upsert(input: UpsertAiWorkspaceByokConfigInput) {
const data = {
provider: input.provider,
name: input.name,
description: input.description,
endpoint: input.endpoint,
sortOrder: input.sortOrder,
enabled: input.enabled,
updatedBy: input.userId,
...(input.encryptedApiKey
? {
encryptedApiKey: input.encryptedApiKey,
lastValidatedAt: new Date(),
lastValidationError: null,
disabledReason: null,
lastError: null,
lastErrorAt: null,
}
: {}),
};
return input.id
? await this.db.aiWorkspaceByokConfig.update({
where: { id: input.id, workspaceId: input.workspaceId },
data,
})
: await this.db.aiWorkspaceByokConfig.create({
data: {
...data,
encryptedApiKey: input.encryptedApiKey ?? '',
workspaceId: input.workspaceId,
createdBy: input.userId,
},
});
}
@Transactional()
async reorder(workspaceId: string, ids: string[], userId?: string) {
await Promise.all(
ids.map((id, sortOrder) =>
this.db.aiWorkspaceByokConfig.update({
where: { id, workspaceId },
data: { sortOrder, updatedBy: userId },
})
)
);
}
@Transactional()
async delete(workspaceId: string, id: string) {
await this.db.aiWorkspaceByokConfig.delete({ where: { id, workspaceId } });
}
@Transactional()
async clear(workspaceId: string, provider?: string | null) {
await this.db.aiWorkspaceByokConfig.deleteMany({
where: { workspaceId, ...(provider ? { provider } : {}) },
});
}
@Transactional()
async markValidated(workspaceId: string, id: string, userId?: string) {
await this.db.aiWorkspaceByokConfig.update({
where: { id, workspaceId },
data: {
enabled: true,
disabledReason: null,
lastValidatedAt: new Date(),
lastValidationError: null,
lastError: null,
lastErrorAt: null,
updatedBy: userId,
},
});
}
@Transactional()
async markFailure(workspaceId: string, id: string, message: string) {
await this.db.aiWorkspaceByokConfig.update({
where: { id, workspaceId },
data: {
enabled: false,
disabledReason: 'recent_failure',
lastValidationError: message,
lastError: message,
lastErrorAt: new Date(),
},
});
}
@Transactional()
async touchUsed(workspaceId: string, id: string) {
await this.db.aiWorkspaceByokConfig.updateMany({
where: { id, workspaceId },
data: { lastUsedAt: new Date() },
});
}
}

View File

@@ -1005,20 +1005,26 @@ export class CopilotSessionModel extends BaseModel {
.filter(({ promptAction }) => !promptAction) .filter(({ promptAction }) => !promptAction)
.map(({ messageCost }) => messageCost) .map(({ messageCost }) => messageCost)
.reduce((prev, cost) => prev + cost, 0); .reduce((prev, cost) => prev + cost, 0);
const [actionRunCost, legacyActionSessionCost, transcriptSettlementCost] = const [
await Promise.all([ actionRunCost,
this.models.copilotActionRun.countSucceededByUser(userId), legacyActionSessionCost,
this.models.copilotActionRun.countLegacyPromptActionSessionsWithoutRun( transcriptSettlementCost,
userId byokQuotaExemptCost,
), ] = await Promise.all([
this.models.copilotTranscriptTask.countSettledByUser(userId), this.models.copilotActionRun.countSucceededByUser(userId),
]); this.models.copilotActionRun.countLegacyPromptActionSessionsWithoutRun(
return ( userId
),
this.models.copilotTranscriptTask.countSettledByUser(userId),
this.models.copilotUsage.countQuotaExemptByokUsage(userId),
]);
const quotaBackedCost =
regularMessageCost + regularMessageCost +
actionRunCost + actionRunCost +
legacyActionSessionCost + legacyActionSessionCost +
transcriptSettlementCost transcriptSettlementCost -
); byokQuotaExemptCost;
return Math.max(0, quotaBackedCost);
} }
async cleanupEmptySessions(earlyThen: Date) { async cleanupEmptySessions(earlyThen: Date) {

View File

@@ -0,0 +1,142 @@
import { Injectable } from '@nestjs/common';
import { Transactional } from '@nestjs-cls/transactional';
import { Prisma } from '@prisma/client';
import { BaseModel } from './base';
type CreateAiUsageEventInput = {
workspaceId: string;
userId?: string;
provider: string;
providerSource: string;
featureKind: string;
model?: string | null;
sessionId?: string;
taskId?: string;
actionId?: string;
billingUnitId?: string;
promptTokens?: number;
completionTokens?: number;
totalTokens?: number;
cachedTokens?: number;
};
type UsageAggregateRow = {
date: string;
featureKind: string;
totalTokens: number | bigint | null;
};
type CountRow = {
count: number | bigint;
};
const BYOK_PROVIDER_SOURCES = ['byok_server', 'byok_local'];
const QUOTA_EXEMPT_BYOK_FEATURES = ['chat', 'action', 'image', 'transcript'];
@Injectable()
export class CopilotUsageModel extends BaseModel {
@Transactional()
async create(input: CreateAiUsageEventInput) {
await this.db.aiUsageEvent.create({
data: {
workspaceId: input.workspaceId,
userId: input.userId,
provider: input.provider,
providerSource: input.providerSource,
featureKind: input.featureKind,
model: input.model ?? null,
sessionId: input.sessionId,
taskId: input.taskId,
actionId: input.actionId,
billingUnitId: input.billingUnitId,
promptTokens: input.promptTokens ?? 0,
completionTokens: input.completionTokens ?? 0,
totalTokens: input.totalTokens ?? 0,
cachedTokens: input.cachedTokens ?? 0,
},
});
}
async countQuotaExemptByokUsage(userId: string) {
const rows = await this.db.$queryRaw<CountRow[]>(Prisma.sql`
WITH "byok_usage" AS (
SELECT "billing_unit_id", "feature_kind"
FROM "ai_usage_events"
WHERE "user_id" = ${userId}
AND "provider_source" IN (${Prisma.join(BYOK_PROVIDER_SOURCES)})
AND "feature_kind" IN (${Prisma.join(QUOTA_EXEMPT_BYOK_FEATURES)})
AND "billing_unit_id" IS NOT NULL
),
"message_units" AS (
SELECT DISTINCT "usage"."billing_unit_id"
FROM "byok_usage" AS "usage"
JOIN "ai_sessions_messages" AS "message"
ON "message"."id" = "usage"."billing_unit_id"
JOIN "ai_sessions_metadata" AS "session"
ON "session"."id" = "message"."session_id"
WHERE "usage"."feature_kind" IN ('chat', 'action', 'image')
AND "message"."role" = 'user'
AND "session"."user_id" = ${userId}
AND ("session"."prompt_action" IS NULL OR "session"."prompt_action" = '')
),
"action_units" AS (
SELECT DISTINCT "usage"."billing_unit_id"
FROM "byok_usage" AS "usage"
JOIN "ai_action_runs" AS "run"
ON "run"."id" = "usage"."billing_unit_id"
WHERE "usage"."feature_kind" IN ('action', 'image')
AND "run"."user_id" = ${userId}
AND "run"."status" = 'succeeded'
AND "run"."action_id" NOT LIKE 'transcript.audio.%'
),
"transcript_units" AS (
SELECT DISTINCT "usage"."billing_unit_id"
FROM "byok_usage" AS "usage"
JOIN "ai_transcript_tasks" AS "task"
ON "task"."id" = "usage"."billing_unit_id"
WHERE "usage"."feature_kind" = 'transcript'
AND "task"."user_id" = ${userId}
AND "task"."status" = 'settled'
)
SELECT (
(SELECT COUNT(*) FROM "message_units") +
(SELECT COUNT(*) FROM "action_units") +
(SELECT COUNT(*) FROM "transcript_units")
) AS "count"
`);
const count = rows[0]?.count ?? 0;
return typeof count === 'bigint' ? Number(count) : count;
}
async aggregateByDay(input: {
workspaceId: string;
from: Date;
to: Date;
providerSources: string[];
}) {
if (!input.providerSources.length) return [];
const rows = await this.db.$queryRaw<UsageAggregateRow[]>(Prisma.sql`
SELECT
to_char(date_trunc('day', "created_at" AT TIME ZONE 'UTC'), 'YYYY-MM-DD') AS "date",
"feature_kind" AS "featureKind",
COALESCE(SUM("total_tokens"), 0)::bigint AS "totalTokens"
FROM "ai_usage_events"
WHERE "workspace_id" = ${input.workspaceId}
AND "provider_source" IN (${Prisma.join(input.providerSources)})
AND "created_at" >= ${input.from}
AND "created_at" < ${input.to}
GROUP BY 1, 2
ORDER BY 1 ASC, 2 ASC
`);
return rows.map(row => {
return {
date: new Date(`${row.date}T00:00:00.000Z`),
featureKind: row.featureKind,
totalTokens: Number(row.totalTokens ?? 0),
};
});
}
}

View File

@@ -17,10 +17,12 @@ import { CommentModel } from './comment';
import { CommentAttachmentModel } from './comment-attachment'; import { CommentAttachmentModel } from './comment-attachment';
import { AppConfigModel } from './config'; import { AppConfigModel } from './config';
import { CopilotActionRunModel } from './copilot-action-run'; import { CopilotActionRunModel } from './copilot-action-run';
import { CopilotWorkspaceByokConfigModel } from './copilot-byok';
import { CopilotContextModel } from './copilot-context'; import { CopilotContextModel } from './copilot-context';
import { CopilotJobModel } from './copilot-job'; import { CopilotJobModel } from './copilot-job';
import { CopilotSessionModel } from './copilot-session'; import { CopilotSessionModel } from './copilot-session';
import { CopilotTranscriptTaskModel } from './copilot-transcript-task'; import { CopilotTranscriptTaskModel } from './copilot-transcript-task';
import { CopilotUsageModel } from './copilot-usage';
import { CopilotWorkspaceConfigModel } from './copilot-workspace'; import { CopilotWorkspaceConfigModel } from './copilot-workspace';
import { DocModel } from './doc'; import { DocModel } from './doc';
import { DocUserModel } from './doc-user'; import { DocUserModel } from './doc-user';
@@ -58,10 +60,12 @@ const MODELS = {
notification: NotificationModel, notification: NotificationModel,
userSettings: UserSettingsModel, userSettings: UserSettingsModel,
copilotSession: CopilotSessionModel, copilotSession: CopilotSessionModel,
copilotUsage: CopilotUsageModel,
copilotTranscriptTask: CopilotTranscriptTaskModel, copilotTranscriptTask: CopilotTranscriptTaskModel,
copilotActionRun: CopilotActionRunModel, copilotActionRun: CopilotActionRunModel,
copilotContext: CopilotContextModel, copilotContext: CopilotContextModel,
copilotWorkspace: CopilotWorkspaceConfigModel, copilotWorkspace: CopilotWorkspaceConfigModel,
copilotWorkspaceByokConfig: CopilotWorkspaceByokConfigModel,
copilotJob: CopilotJobModel, copilotJob: CopilotJobModel,
appConfig: AppConfigModel, appConfig: AppConfigModel,
comment: CommentModel, comment: CommentModel,
@@ -133,10 +137,12 @@ export * from './calendar-subscription';
export * from './comment'; export * from './comment';
export * from './comment-attachment'; export * from './comment-attachment';
export * from './common'; export * from './common';
export * from './copilot-byok';
export * from './copilot-context'; export * from './copilot-context';
export * from './copilot-job'; export * from './copilot-job';
export * from './copilot-session'; export * from './copilot-session';
export * from './copilot-transcript-task'; export * from './copilot-transcript-task';
export * from './copilot-usage';
export * from './copilot-workspace'; export * from './copilot-workspace';
export * from './doc'; export * from './doc';
export * from './doc-user'; export * from './doc-user';

View File

@@ -458,6 +458,7 @@ type LlmRerankResponse = {
export type LlmToolLoopStreamEvent = export type LlmToolLoopStreamEvent =
| { type: 'message_start'; id?: string; model?: string } | { type: 'message_start'; id?: string; model?: string }
| { type: 'provider_selected'; provider_id: string }
| { type: 'text_delta'; text: string } | { type: 'text_delta'; text: string }
| { type: 'reasoning_delta'; text: string } | { type: 'reasoning_delta'; text: string }
| { | {
@@ -537,7 +538,14 @@ function parseLlmEventJson(eventJson: string): LlmStreamEvent {
function parseLlmToolLoopStreamEvent( function parseLlmToolLoopStreamEvent(
eventJson: string eventJson: string
): LlmToolLoopStreamEvent { ): LlmToolLoopStreamEvent {
return parseToolLoopStreamEvent(parseLlmEventJson(eventJson)); const event = parseLlmEventJson(eventJson);
if (
event.type === 'provider_selected' &&
typeof event.provider_id === 'string'
) {
return event;
}
return parseToolLoopStreamEvent(event);
} }
export function llmMatchModelCapabilities( export function llmMatchModelCapabilities(

View File

@@ -0,0 +1,44 @@
import type { ByokFeatureKind } from '../byok/types';
export type ByokSourceCoverage = {
local: boolean;
server: boolean;
};
export type CopilotFeatureAccessRule = ByokSourceCoverage & {
quotaMetered: boolean;
};
const DEFAULT_BYOK_COVERAGE: ByokSourceCoverage = {
local: true,
server: true,
};
const DEFAULT_FEATURE_ACCESS: CopilotFeatureAccessRule = {
...DEFAULT_BYOK_COVERAGE,
quotaMetered: true,
};
const COPILOT_FEATURE_ACCESS: Partial<
Record<ByokFeatureKind, CopilotFeatureAccessRule>
> = {
transcript: { local: false, server: true, quotaMetered: true },
embedding: { local: false, server: true, quotaMetered: false },
workspace_indexing: { local: false, server: true, quotaMetered: false },
rerank: { local: false, server: true, quotaMetered: false },
};
export function getByokSourceCoverage(
featureKind?: ByokFeatureKind
): ByokSourceCoverage {
const access = getCopilotFeatureAccess(featureKind);
return { local: access.local, server: access.server };
}
export function getCopilotFeatureAccess(
featureKind?: ByokFeatureKind
): CopilotFeatureAccessRule {
return featureKind
? (COPILOT_FEATURE_ACCESS[featureKind] ?? DEFAULT_FEATURE_ACCESS)
: DEFAULT_FEATURE_ACCESS;
}

View File

@@ -0,0 +1,2 @@
export * from './feature-coverage';
export * from './policy';

View File

@@ -0,0 +1,106 @@
import { Injectable } from '@nestjs/common';
import { CopilotQuotaExceeded } from '../../../base';
import { ByokService } from '../byok/service';
import type { ByokFeatureKind } from '../byok/types';
import type { CopilotProviderProfile } from '../config';
import { ConversationPolicy } from '../conversation/policy';
import {
getByokSourceCoverage,
getCopilotFeatureAccess,
} from './feature-coverage';
export type CopilotAccessContext = {
userId?: string;
workspaceId?: string;
byokLeaseId?: string;
featureKind?: ByokFeatureKind;
quotaBackedRoutesAllowed?: boolean;
};
export type CopilotRouteAccess = {
byokProfiles: CopilotProviderProfile[];
quotaBackedRoutesAvailable: boolean;
};
export type CopilotTurnRouteAccess = {
byokProfiles: CopilotProviderProfile[];
quotaBackedRoutesAllowed?: boolean;
};
@Injectable()
export class CopilotAccessPolicy {
constructor(
private readonly conversationPolicy: ConversationPolicy,
private readonly byok: ByokService
) {}
async getByokProfiles(context: CopilotAccessContext = {}) {
const coverage = getByokSourceCoverage(context.featureKind);
return await this.byok.getProfiles(context, coverage);
}
async canUseQuotaBackedRoutes(context: CopilotAccessContext = {}) {
if (context.quotaBackedRoutesAllowed !== undefined) {
return context.quotaBackedRoutesAllowed;
}
if (!getCopilotFeatureAccess(context.featureKind).quotaMetered) {
return true;
}
if (!context.userId) {
return true;
}
return await this.conversationPolicy.hasQuota(context.userId);
}
async getQuota(userId: string) {
return await this.conversationPolicy.getQuota(userId);
}
async checkQuota(userId: string) {
await this.conversationPolicy.checkQuota(userId);
}
async resolveRouteAccess(
context: CopilotAccessContext = {}
): Promise<CopilotRouteAccess> {
const [byokProfiles, quotaBackedRoutesAvailable] = await Promise.all([
this.getByokProfiles(context),
this.canUseQuotaBackedRoutes(context),
]);
return { byokProfiles, quotaBackedRoutesAvailable };
}
async resolveTurnRouteAccess(
context: CopilotAccessContext
): Promise<CopilotTurnRouteAccess> {
const byokProfiles = await this.getByokProfiles(context);
if (context.quotaBackedRoutesAllowed === false) {
return { byokProfiles, quotaBackedRoutesAllowed: false };
}
const featureAccess = getCopilotFeatureAccess(context.featureKind);
if (!byokProfiles.length && context.userId && featureAccess.quotaMetered) {
await this.conversationPolicy.checkQuota(context.userId);
}
const quotaBackedRoutesAllowed = byokProfiles.length
? context.quotaBackedRoutesAllowed
: true;
return { byokProfiles, quotaBackedRoutesAllowed };
}
async assertQuotaOrByok(context: CopilotAccessContext) {
const byokProfiles = await this.getByokProfiles(context);
if (context.quotaBackedRoutesAllowed === false) {
if (!byokProfiles.length) {
throw new CopilotQuotaExceeded();
}
return;
}
const featureAccess = getCopilotFeatureAccess(context.featureKind);
if (!byokProfiles.length && context.userId && featureAccess.quotaMetered) {
await this.conversationPolicy.checkQuota(context.userId);
}
}
}

View File

@@ -0,0 +1,4 @@
export { ByokEntitlementPolicy } from './policy';
export { WorkspaceByokResolver } from './resolver';
export { type ByokProviderRequestContext, ByokService } from './service';
export * from './types';

View File

@@ -0,0 +1,105 @@
import { Injectable } from '@nestjs/common';
import { ActionForbidden } from '../../../base';
import { Models, WorkspaceRole } from '../../../models';
@Injectable()
export class ByokEntitlementPolicy {
constructor(private readonly models: Models) {}
private isUserPlanEntitled(features: string[]) {
return (
features.includes('pro_plan_v1') ||
features.includes('lifetime_pro_plan_v1') ||
features.includes('unlimited_copilot')
);
}
async hasAiPlan(userId?: string) {
if (!userId) return false;
const features = await this.models.userFeature.list(userId);
return this.isUserPlanEntitled(features);
}
async hasManagementAccess(workspaceId: string, userId?: string) {
if (!userId) return false;
const role = await this.models.workspaceUser.getActive(workspaceId, userId);
return (
role?.type === WorkspaceRole.Owner || role?.type === WorkspaceRole.Admin
);
}
async assertManagementAccess(workspaceId: string, userId?: string) {
if (!(await this.hasManagementAccess(workspaceId, userId))) {
throw new ActionForbidden(
'BYOK settings require workspace owner or admin.'
);
}
}
private async getWorkspaceOwnerId(workspaceId: string) {
const workspace = await this.models.workspace.get(workspaceId);
if (!workspace) {
return null;
}
try {
return (await this.models.workspaceUser.getOwner(workspaceId)).id;
} catch (error) {
if (
error instanceof Error &&
error.message === 'Workspace owner not found'
) {
return null;
}
throw error;
}
}
async hasLocalEntitlement(workspaceId: string, userId?: string) {
if (env.selfhosted) return true;
if (await this.models.workspaceFeature.has(workspaceId, 'team_plan_v1')) {
return true;
}
const ownerId = await this.getWorkspaceOwnerId(workspaceId);
if (!ownerId) return false;
if (await this.hasAiPlan(userId)) return true;
return await this.hasAiPlan(ownerId);
}
async hasServerEntitlement(workspaceId: string) {
if (env.selfhosted) return true;
if (await this.models.workspaceFeature.has(workspaceId, 'team_plan_v1')) {
return true;
}
const ownerId = await this.getWorkspaceOwnerId(workspaceId);
if (!ownerId) return false;
return await this.hasAiPlan(ownerId);
}
async hasEntitlement(workspaceId: string, userId?: string) {
const [serverEntitled, localEntitled] = await Promise.all([
this.hasServerEntitlement(workspaceId),
this.hasLocalEntitlement(workspaceId, userId),
]);
return [serverEntitled, localEntitled] as const;
}
async assertServerEntitled(workspaceId: string) {
if (!(await this.hasServerEntitlement(workspaceId))) {
throw new ActionForbidden('BYOK requires Pro, Team, or Believer.');
}
}
async assertLocalEntitled(workspaceId: string, userId?: string) {
if (!(await this.hasLocalEntitlement(workspaceId, userId))) {
throw new ActionForbidden('BYOK requires Pro, Team, or Believer.');
}
}
}

View File

@@ -0,0 +1,405 @@
import {
Args,
Field,
ID,
InputType,
Mutation,
ObjectType,
Parent,
ResolveField,
Resolver,
} from '@nestjs/graphql';
import { SafeIntResolver } from 'graphql-scalars';
import { Throttle } from '../../../base';
import { CurrentUser } from '../../../core/auth';
import { AccessController } from '../../../core/permission';
import { WorkspaceType } from '../../../core/workspaces';
import { ByokEntitlementPolicy } from './policy';
import { ByokKeyConfig, ByokLocalLeaseProvider, ByokService } from './service';
import { ByokKeyStorage, ByokKeyTestStatus, ByokProvider } from './types';
@ObjectType()
export class WorkspaceByokKeyConfigType implements ByokKeyConfig {
@Field(() => ID)
id!: string;
@Field(() => ByokProvider)
provider!: ByokProvider;
@Field(() => String)
name!: string;
@Field(() => String, { nullable: true })
description!: string | null;
@Field(() => ByokKeyStorage)
storage!: ByokKeyStorage;
@Field(() => Boolean)
configured!: boolean;
@Field(() => Boolean)
enabled!: boolean;
@Field(() => String, { nullable: true })
endpoint!: string | null;
@Field(() => Boolean)
endpointEditable!: boolean;
@Field(() => SafeIntResolver)
sortOrder!: number;
@Field(() => [String])
capabilities!: string[];
@Field(() => ByokKeyTestStatus)
testStatus!: ByokKeyTestStatus;
@Field(() => String, { nullable: true })
disabledReason!: string | null;
@Field(() => Date, { nullable: true })
lastTestedAt!: Date | null;
@Field(() => String, { nullable: true })
lastTestError!: string | null;
@Field(() => Date, { nullable: true })
lastUsedAt!: Date | null;
@Field(() => Date, { nullable: true })
lastErrorAt!: Date | null;
@Field(() => String, { nullable: true })
lastError!: string | null;
}
@ObjectType()
class WorkspaceByokCapabilityWarningType {
@Field(() => String)
featureKind!: string;
@Field(() => String)
reason!: string;
@Field(() => [ByokProvider])
requiredProviders!: ByokProvider[];
}
@ObjectType()
class WorkspaceByokSettingsType {
@Field(() => String)
workspaceId!: string;
@Field(() => Boolean)
entitled!: boolean;
@Field(() => Boolean)
serverEntitled!: boolean;
@Field(() => Boolean)
localEntitled!: boolean;
@Field(() => [String])
entitlementRequired!: string[];
@Field(() => [WorkspaceByokKeyConfigType])
keys!: WorkspaceByokKeyConfigType[];
@Field(() => [ByokProvider])
allowedProviders!: ByokProvider[];
@Field(() => Boolean)
localStorageSupported!: boolean;
@Field(() => Boolean)
customEndpointSupported!: boolean;
@Field(() => Boolean)
hasAiPlan!: boolean;
@Field(() => [WorkspaceByokCapabilityWarningType])
warnings!: WorkspaceByokCapabilityWarningType[];
}
@ObjectType()
class WorkspaceByokUsagePointType {
@Field(() => Date)
date!: Date;
@Field(() => String)
featureKind!: string;
@Field(() => SafeIntResolver)
totalTokens!: number;
}
@ObjectType()
class TestWorkspaceByokConfigResultType {
@Field(() => Boolean)
ok!: boolean;
@Field(() => ByokKeyTestStatus)
status!: ByokKeyTestStatus;
@Field(() => String, { nullable: true })
message!: string | null;
}
@ObjectType()
class CreateWorkspaceByokLocalLeaseResultType {
@Field(() => String)
leaseId!: string;
@Field(() => Date)
expiresAt!: Date;
}
@InputType()
class UpsertWorkspaceByokConfigInput {
@Field(() => ID, { nullable: true })
id?: string;
@Field(() => String)
workspaceId!: string;
@Field(() => ByokProvider)
provider!: ByokProvider;
@Field(() => String)
name!: string;
@Field(() => String, { nullable: true })
description?: string | null;
@Field(() => ByokKeyStorage)
storage!: ByokKeyStorage;
@Field(() => String, { nullable: true })
apiKey?: string | null;
@Field(() => String, { nullable: true })
endpoint?: string | null;
@Field(() => SafeIntResolver, { nullable: true })
sortOrder?: number | null;
@Field(() => Boolean, { nullable: true })
enabled?: boolean | null;
}
@InputType()
class TestWorkspaceByokConfigInput {
@Field(() => String)
workspaceId!: string;
@Field(() => ByokProvider)
provider!: ByokProvider;
@Field(() => ByokKeyStorage)
storage!: ByokKeyStorage;
@Field(() => String, { nullable: true })
apiKey?: string | null;
@Field(() => String, { nullable: true })
endpoint?: string | null;
@Field(() => ID, { nullable: true })
configId?: string | null;
}
@InputType()
class ReorderWorkspaceByokConfigsInput {
@Field(() => String)
workspaceId!: string;
@Field(() => ByokKeyStorage)
storage!: ByokKeyStorage;
@Field(() => [ID])
ids!: string[];
}
@InputType()
class CreateWorkspaceByokLocalLeaseProviderInput implements ByokLocalLeaseProvider {
@Field(() => ByokProvider)
provider!: ByokProvider;
@Field(() => String)
name!: string;
@Field(() => String, { nullable: true })
description?: string | null;
@Field(() => String)
apiKey!: string;
@Field(() => String, { nullable: true })
endpoint?: string | null;
@Field(() => SafeIntResolver, { nullable: true })
sortOrder?: number | null;
@Field(() => Boolean, { nullable: true })
enabled?: boolean | null;
}
@InputType()
class CreateWorkspaceByokLocalLeaseInput {
@Field(() => String)
workspaceId!: string;
@Field(() => [CreateWorkspaceByokLocalLeaseProviderInput])
providers!: CreateWorkspaceByokLocalLeaseProviderInput[];
}
@Resolver(() => WorkspaceType)
export class WorkspaceByokResolver {
constructor(
private readonly ac: AccessController,
private readonly entitlement: ByokEntitlementPolicy,
private readonly byok: ByokService
) {}
@ResolveField(() => WorkspaceByokSettingsType, {
name: 'byokSettings',
complexity: 2,
})
async settings(
@CurrentUser() user: CurrentUser,
@Parent() workspace: WorkspaceType
) {
await this.ac
.user(user.id)
.workspace(workspace.id)
.allowLocal()
.assert('Workspace.Settings.Read');
await this.entitlement.assertManagementAccess(workspace.id, user.id);
return await this.byok.getSettings(workspace.id, user.id);
}
@ResolveField(() => [WorkspaceByokUsagePointType], {
name: 'byokUsage',
complexity: 2,
})
async usage(
@CurrentUser() user: CurrentUser,
@Parent() workspace: WorkspaceType,
@Args('from', { type: () => Date }) from: Date,
@Args('to', { type: () => Date }) to: Date
) {
await this.ac
.user(user.id)
.workspace(workspace.id)
.allowLocal()
.assert('Workspace.Settings.Read');
await this.entitlement.assertManagementAccess(workspace.id, user.id);
return await this.byok.getUsage(workspace.id, from, to);
}
@Throttle('strict')
@Mutation(() => TestWorkspaceByokConfigResultType)
async testWorkspaceByokConfig(
@CurrentUser() user: CurrentUser,
@Args('input') input: TestWorkspaceByokConfigInput
) {
await this.ac
.user(user.id)
.workspace(input.workspaceId)
.allowLocal()
.assert('Workspace.Settings.Update');
await this.entitlement.assertManagementAccess(input.workspaceId, user.id);
if (input.storage === ByokKeyStorage.server) {
await this.entitlement.assertServerEntitled(input.workspaceId);
} else {
await this.entitlement.assertLocalEntitled(input.workspaceId, user.id);
}
return await this.byok.testConfig({ ...input, userId: user.id });
}
@Mutation(() => WorkspaceByokKeyConfigType)
@Throttle('strict')
async upsertWorkspaceByokConfig(
@CurrentUser() user: CurrentUser,
@Args('input') input: UpsertWorkspaceByokConfigInput
) {
await this.ac
.user(user.id)
.workspace(input.workspaceId)
.allowLocal()
.assert('Workspace.Settings.Update');
await this.entitlement.assertManagementAccess(input.workspaceId, user.id);
await this.entitlement.assertServerEntitled(input.workspaceId);
return await this.byok.upsertConfig({ ...input, userId: user.id });
}
@Mutation(() => [WorkspaceByokKeyConfigType])
@Throttle('strict')
async reorderWorkspaceByokConfigs(
@CurrentUser() user: CurrentUser,
@Args('input') input: ReorderWorkspaceByokConfigsInput
) {
await this.ac
.user(user.id)
.workspace(input.workspaceId)
.allowLocal()
.assert('Workspace.Settings.Update');
await this.entitlement.assertManagementAccess(input.workspaceId, user.id);
await this.entitlement.assertServerEntitled(input.workspaceId);
return await this.byok.reorderConfigs({ ...input, userId: user.id });
}
@Mutation(() => Boolean)
@Throttle('strict')
async deleteWorkspaceByokConfig(
@CurrentUser() user: CurrentUser,
@Args('id', { type: () => ID }) id: string,
@Args('workspaceId', { type: () => String }) workspaceId: string
) {
await this.ac
.user(user.id)
.workspace(workspaceId)
.allowLocal()
.assert('Workspace.Settings.Update');
await this.entitlement.assertManagementAccess(workspaceId, user.id);
await this.entitlement.assertServerEntitled(workspaceId);
return await this.byok.deleteConfig(workspaceId, id, user.id);
}
@Mutation(() => Boolean)
@Throttle('strict')
async clearWorkspaceByokConfigs(
@CurrentUser() user: CurrentUser,
@Args('workspaceId', { type: () => String }) workspaceId: string,
@Args('provider', { type: () => ByokProvider, nullable: true })
provider?: ByokProvider | null
) {
await this.ac
.user(user.id)
.workspace(workspaceId)
.allowLocal()
.assert('Workspace.Settings.Update');
await this.entitlement.assertManagementAccess(workspaceId, user.id);
await this.entitlement.assertServerEntitled(workspaceId);
return await this.byok.clearConfigs(workspaceId, provider, user.id);
}
@Mutation(() => CreateWorkspaceByokLocalLeaseResultType)
@Throttle('strict')
async createWorkspaceByokLocalLease(
@CurrentUser() user: CurrentUser,
@Args('input') input: CreateWorkspaceByokLocalLeaseInput
) {
await this.ac
.user(user.id)
.workspace(input.workspaceId)
.allowLocal()
.assert('Workspace.Copilot');
await this.entitlement.assertManagementAccess(input.workspaceId, user.id);
await this.entitlement.assertLocalEntitled(input.workspaceId, user.id);
return await this.byok.createLocalLease({ ...input, userId: user.id });
}
}

View File

@@ -0,0 +1,883 @@
import { createHash, createHmac, randomUUID } from 'node:crypto';
import { BadRequestException, Injectable } from '@nestjs/common';
import { BadRequest, Cache, CryptoHelper, metrics } from '../../../base';
import { Models } from '../../../models';
import type { CopilotProviderProfile } from '../config';
import { ByokEntitlementPolicy } from './policy';
import {
BYOK_ALLOWED_PROVIDERS,
type ByokFeatureKind,
ByokKeyStorage,
ByokKeyTestStatus,
ByokProvider,
ByokProviderSource,
byokProviderToCopilotType,
isByokProvider,
} from './types';
const LOCAL_LEASE_TTL_MS = 10 * 60 * 1000;
const BYOK_PROFILE_PRIORITY_BASE = 10_000;
const SERVER_PROFILE_PRIORITY_OFFSET = 2_000;
const TEST_TIMEOUT_MS = 10_000;
export type ByokProviderRequestContext = {
userId?: string;
workspaceId?: string;
byokLeaseId?: string;
};
export type ByokProfileSourceFilter = {
local?: boolean;
server?: boolean;
};
export type ByokKeyConfig = {
id: string;
provider: ByokProvider;
name: string;
description: string | null;
storage: ByokKeyStorage;
configured: boolean;
enabled: boolean;
endpoint: string | null;
endpointEditable: boolean;
sortOrder: number;
capabilities: string[];
testStatus: ByokKeyTestStatus;
disabledReason: string | null;
lastTestedAt: Date | null;
lastTestError: string | null;
lastUsedAt: Date | null;
lastErrorAt: Date | null;
lastError: string | null;
};
export type ByokSettings = {
workspaceId: string;
entitled: boolean;
serverEntitled: boolean;
localEntitled: boolean;
entitlementRequired: string[];
keys: ByokKeyConfig[];
allowedProviders: ByokProvider[];
localStorageSupported: boolean;
customEndpointSupported: boolean;
hasAiPlan: boolean;
warnings: Array<{
featureKind: string;
reason: string;
requiredProviders: ByokProvider[];
}>;
};
export type ByokLocalLeaseProvider = {
provider: ByokProvider;
name: string;
description?: string | null;
apiKey: string;
endpoint?: string | null;
sortOrder?: number | null;
enabled?: boolean | null;
};
type LocalLeasePayload = {
workspaceId: string;
userId: string;
providers: Array<
Omit<ByokLocalLeaseProvider, 'apiKey'> & { encryptedApiKey: string }
>;
};
type LocalLeaseActive = {
leaseId: string;
expiresAt: string;
};
type ByokProfileMeta = {
source: ByokProviderSource.Server | ByokProviderSource.Local;
keyId?: string;
provider: ByokProvider;
};
@Injectable()
export class ByokService {
constructor(
private readonly models: Models,
private readonly crypto: CryptoHelper,
private readonly cache: Cache,
private readonly entitlement: ByokEntitlementPolicy
) {}
get customEndpointSupported() {
return env.selfhosted;
}
async getSettings(
workspaceId: string,
userId?: string
): Promise<ByokSettings> {
if (!(await this.entitlement.hasManagementAccess(workspaceId, userId))) {
return {
workspaceId,
entitled: false,
serverEntitled: false,
localEntitled: false,
entitlementRequired: ['Workspace owner or admin'],
keys: [],
allowedProviders: [...BYOK_ALLOWED_PROVIDERS],
localStorageSupported: false,
customEndpointSupported: this.customEndpointSupported,
hasAiPlan: await this.entitlement.hasAiPlan(userId),
warnings: [],
};
}
const [serverEntitled, localEntitled] =
await this.entitlement.hasEntitlement(workspaceId, userId);
const entitled = serverEntitled || localEntitled;
if (!entitled) {
return {
workspaceId,
entitled: false,
serverEntitled: false,
localEntitled: false,
entitlementRequired: ['Pro', 'Team', 'Believer'],
keys: [],
allowedProviders: [...BYOK_ALLOWED_PROVIDERS],
localStorageSupported: false,
customEndpointSupported: this.customEndpointSupported,
hasAiPlan: await this.entitlement.hasAiPlan(userId),
warnings: [],
};
}
const rows = serverEntitled
? await this.models.copilotWorkspaceByokConfig.list(workspaceId)
: [];
const keys = rows.map(row => this.toKeyConfig(row));
return {
workspaceId,
entitled: true,
serverEntitled,
localEntitled,
entitlementRequired: ['Pro', 'Team', 'Believer'],
keys,
allowedProviders: [...BYOK_ALLOWED_PROVIDERS],
localStorageSupported: false,
customEndpointSupported: this.customEndpointSupported,
hasAiPlan: await this.entitlement.hasAiPlan(userId),
warnings: this.buildWarnings(keys),
};
}
async upsertConfig(input: {
id?: string | null;
workspaceId: string;
provider: ByokProvider;
name: string;
description?: string | null;
storage: ByokKeyStorage;
apiKey?: string | null;
endpoint?: string | null;
sortOrder?: number | null;
enabled?: boolean | null;
userId?: string;
}): Promise<ByokKeyConfig> {
await this.entitlement.assertManagementAccess(
input.workspaceId,
input.userId
);
await this.entitlement.assertServerEntitled(input.workspaceId);
this.assertProvider(input.provider);
if (input.storage !== ByokKeyStorage.server) {
throw new BadRequestException('Only server BYOK keys are persisted.');
}
const existing = input.id
? await this.models.copilotWorkspaceByokConfig.get(input.id)
: null;
if (input.id && (!existing || existing.workspaceId !== input.workspaceId)) {
throw new BadRequest('BYOK config not found.');
}
const encryptedApiKey = input.apiKey
? this.crypto.encrypt(input.apiKey)
: undefined;
if (!input.id && !encryptedApiKey) {
throw new BadRequestException('apiKey is required.');
}
const description =
input.description !== undefined
? input.description?.trim() || null
: (existing?.description ?? null);
const endpoint =
input.endpoint !== undefined
? this.normalizeEndpoint(input.endpoint)
: (existing?.endpoint ?? null);
const sortOrder = input.sortOrder ?? existing?.sortOrder ?? 0;
const enabled = input.enabled ?? existing?.enabled ?? true;
const row = await this.models.copilotWorkspaceByokConfig.upsert({
id: input.id,
workspaceId: input.workspaceId,
provider: input.provider,
name: input.name.trim(),
description,
encryptedApiKey,
endpoint,
sortOrder,
enabled,
userId: input.userId,
});
return this.toKeyConfig(row);
}
async reorderConfigs(input: {
workspaceId: string;
storage: ByokKeyStorage;
ids: string[];
userId?: string;
}) {
await this.entitlement.assertManagementAccess(
input.workspaceId,
input.userId
);
await this.entitlement.assertServerEntitled(input.workspaceId);
if (input.storage !== ByokKeyStorage.server) {
throw new BadRequestException('Only server BYOK keys are persisted.');
}
await this.models.copilotWorkspaceByokConfig.reorder(
input.workspaceId,
input.ids,
input.userId
);
return (await this.getSettings(input.workspaceId, input.userId)).keys;
}
async deleteConfig(workspaceId: string, id: string, _userId?: string) {
await this.entitlement.assertManagementAccess(workspaceId, _userId);
await this.entitlement.assertServerEntitled(workspaceId);
await this.models.copilotWorkspaceByokConfig.delete(workspaceId, id);
return true;
}
async clearConfigs(
workspaceId: string,
provider: ByokProvider | null | undefined,
_userId?: string
) {
await this.entitlement.assertManagementAccess(workspaceId, _userId);
await this.entitlement.assertServerEntitled(workspaceId);
await this.models.copilotWorkspaceByokConfig.clear(workspaceId, provider);
return true;
}
async testConfig(input: {
workspaceId: string;
provider: ByokProvider;
storage: ByokKeyStorage;
apiKey?: string | null;
endpoint?: string | null;
configId?: string | null;
userId?: string;
}) {
await this.entitlement.assertManagementAccess(
input.workspaceId,
input.userId
);
if (input.storage === ByokKeyStorage.server) {
await this.entitlement.assertServerEntitled(input.workspaceId);
} else {
await this.entitlement.assertLocalEntitled(
input.workspaceId,
input.userId
);
}
this.assertProvider(input.provider);
let apiKey = input.apiKey;
let endpoint = this.normalizeEndpoint(input.endpoint);
if (!apiKey && input.configId && input.storage === ByokKeyStorage.server) {
const config = await this.models.copilotWorkspaceByokConfig.get(
input.configId
);
if (
!config ||
config.workspaceId !== input.workspaceId ||
config.provider !== input.provider
) {
throw new BadRequestException('BYOK config not found.');
}
apiKey = this.crypto.decrypt(config.encryptedApiKey);
endpoint =
input.endpoint !== undefined
? endpoint
: this.normalizeEndpoint(config.endpoint);
}
if (!apiKey) {
throw new BadRequestException('apiKey is required.');
}
try {
await this.runProviderProbe(input.provider, apiKey, endpoint);
if (input.configId && input.storage === ByokKeyStorage.server) {
await this.models.copilotWorkspaceByokConfig.markValidated(
input.workspaceId,
input.configId,
input.userId
);
}
metrics.ai.counter('byok_test_key').add(1, {
workspace: input.workspaceId,
provider: input.provider,
storage: input.storage,
result: 'passed',
});
return { ok: true, status: ByokKeyTestStatus.passed, message: null };
} catch (error) {
const message = this.sanitizeError(error);
if (input.configId && input.storage === ByokKeyStorage.server) {
await this.models.copilotWorkspaceByokConfig.markFailure(
input.workspaceId,
input.configId,
message
);
}
metrics.ai.counter('byok_test_key').add(1, {
workspace: input.workspaceId,
provider: input.provider,
storage: input.storage,
result: 'failed',
});
return { ok: false, status: ByokKeyTestStatus.failed, message };
}
}
async createLocalLease(input: {
workspaceId: string;
providers: ByokLocalLeaseProvider[];
userId: string;
}) {
await this.entitlement.assertManagementAccess(
input.workspaceId,
input.userId
);
await this.entitlement.assertLocalEntitled(input.workspaceId, input.userId);
const providers = input.providers.map(provider => {
this.assertProvider(provider.provider);
const endpoint = this.normalizeEndpoint(provider.endpoint);
return { ...provider, endpoint };
});
const activeCacheKey = this.localLeaseActiveCacheKey({
...input,
providers,
});
const activeLease = await this.getActiveLocalLease(activeCacheKey);
if (activeLease) return activeLease;
const leaseId = randomUUID();
const expiresAt = new Date(Date.now() + LOCAL_LEASE_TTL_MS);
const payload: LocalLeasePayload = {
workspaceId: input.workspaceId,
userId: input.userId,
providers: providers.map(provider => ({
provider: provider.provider,
name: provider.name,
description: provider.description,
encryptedApiKey: this.crypto.encrypt(provider.apiKey),
endpoint: provider.endpoint,
sortOrder: provider.sortOrder,
enabled: provider.enabled,
})),
};
await this.cache.set(this.leaseCacheKey(leaseId), payload, {
ttl: LOCAL_LEASE_TTL_MS,
});
const registered = await this.cache.setnx<LocalLeaseActive>(
activeCacheKey,
{ leaseId, expiresAt: expiresAt.toISOString() },
{ ttl: LOCAL_LEASE_TTL_MS }
);
if (!registered) {
const current = await this.getActiveLocalLease(activeCacheKey);
if (current) {
await this.cache.delete(this.leaseCacheKey(leaseId));
return current;
}
}
return { leaseId, expiresAt };
}
async getProfiles(
context: ByokProviderRequestContext = {},
sources: ByokProfileSourceFilter = { local: true, server: true }
): Promise<CopilotProviderProfile[]> {
if (!context.workspaceId) {
return [];
}
const [localEntitled, serverEntitled] = await Promise.all([
this.entitlement.hasLocalEntitlement(context.workspaceId, context.userId),
this.entitlement.hasServerEntitlement(context.workspaceId),
]);
const [localProfiles, serverProfiles] = await Promise.all([
sources.local && localEntitled
? this.getLocalProfiles(context)
: Promise.resolve([]),
sources.server && serverEntitled
? this.getServerProfiles(context.workspaceId)
: Promise.resolve([]),
]);
return [...localProfiles, ...serverProfiles];
}
async recordUsage(input: {
workspaceId?: string;
userId?: string;
providerId?: string;
model?: string | null;
featureKind: ByokFeatureKind;
sessionId?: string;
taskId?: string;
actionId?: string;
billingUnitId?: string;
usage?: {
prompt_tokens?: number;
completion_tokens?: number;
total_tokens?: number;
cached_tokens?: number;
};
}) {
if (!input.workspaceId || !input.providerId) return;
const meta = this.parseProfileMeta(input.providerId, input.workspaceId);
if (!meta) return;
metrics.ai.counter('byok_usage').add(1, {
workspace: input.workspaceId,
provider: meta.provider,
source: meta.source,
feature: input.featureKind,
});
await this.models.copilotUsage.create({
workspaceId: input.workspaceId,
userId: input.userId,
provider: meta.provider,
providerSource: meta.source,
featureKind: input.featureKind,
model: input.model ?? null,
sessionId: input.sessionId,
taskId: input.taskId,
actionId: input.actionId,
billingUnitId: input.billingUnitId,
promptTokens: input.usage?.prompt_tokens ?? 0,
completionTokens: input.usage?.completion_tokens ?? 0,
totalTokens: input.usage?.total_tokens ?? 0,
cachedTokens: input.usage?.cached_tokens ?? 0,
});
if (meta.source === ByokProviderSource.Server && meta.keyId) {
await this.models.copilotWorkspaceByokConfig.touchUsed(
input.workspaceId,
meta.keyId
);
}
}
async recordProviderFailure(input: {
workspaceId?: string;
providerId?: string;
featureKind: ByokFeatureKind;
error: unknown;
}) {
if (!input.workspaceId || !input.providerId) return;
const meta = this.parseProfileMeta(input.providerId, input.workspaceId);
if (!meta) return;
const message = this.sanitizeError(input.error);
metrics.ai.counter('byok_route_failure').add(1, {
workspace: input.workspaceId,
provider: meta.provider,
source: meta.source,
feature: input.featureKind,
});
if (meta.source === ByokProviderSource.Server && meta.keyId) {
await this.models.copilotWorkspaceByokConfig.markFailure(
input.workspaceId,
meta.keyId,
message
);
}
}
async getUsage(workspaceId: string, from: Date, to: Date) {
return await this.models.copilotUsage.aggregateByDay({
workspaceId,
from,
to,
providerSources: [ByokProviderSource.Server, ByokProviderSource.Local],
});
}
private async getServerProfiles(workspaceId: string) {
const rows =
await this.models.copilotWorkspaceByokConfig.listEnabled(workspaceId);
return rows
.filter(row => isByokProvider(row.provider))
.map((row, index): CopilotProviderProfile => {
const provider = row.provider as ByokProvider;
return {
id: this.profileId(workspaceId, provider, row.id, 'server'),
type: byokProviderToCopilotType(provider),
priority:
BYOK_PROFILE_PRIORITY_BASE - SERVER_PROFILE_PRIORITY_OFFSET - index,
config: this.providerConfig(
provider,
row.encryptedApiKey,
row.endpoint
),
} as CopilotProviderProfile;
});
}
private async getLocalProfiles(context: ByokProviderRequestContext) {
if (!context.byokLeaseId || !context.workspaceId || !context.userId) {
return [];
}
if (
!(await this.entitlement.hasManagementAccess(
context.workspaceId,
context.userId
))
) {
return [];
}
const lease = await this.cache.get<LocalLeasePayload>(
this.leaseCacheKey(context.byokLeaseId)
);
if (
!lease ||
lease.workspaceId !== context.workspaceId ||
lease.userId !== context.userId
) {
return [];
}
return lease.providers
.filter(provider => provider.enabled !== false)
.map((provider, index): CopilotProviderProfile => {
return {
id: this.profileId(
context.workspaceId ?? lease.workspaceId,
provider.provider,
`${index}`,
'local'
),
type: byokProviderToCopilotType(provider.provider),
priority: BYOK_PROFILE_PRIORITY_BASE - index,
config: this.providerConfig(
provider.provider,
provider.encryptedApiKey,
provider.endpoint ?? null
),
} as CopilotProviderProfile;
});
}
private providerConfig(
provider: ByokProvider,
encryptedApiKey: string,
endpoint: string | null
) {
const apiKey = this.crypto.decrypt(encryptedApiKey);
switch (provider) {
case ByokProvider.openai:
case ByokProvider.gemini:
case ByokProvider.anthropic:
return { apiKey, ...(endpoint ? { baseURL: endpoint } : {}) };
case ByokProvider.fal:
return { apiKey };
}
}
private profileId(
workspaceId: string,
provider: ByokProvider,
keyId: string,
storage: 'server' | 'local'
) {
const hash = this.workspaceHash(workspaceId);
const sanitizedKeyId = keyId.replaceAll(/[^a-zA-Z0-9-_]/g, '');
return storage === 'local'
? `byok-${hash}-${provider}-local-${sanitizedKeyId}`
: `byok-${hash}-${provider}-${sanitizedKeyId}`;
}
parseProfileMeta(
providerId: string,
workspaceId?: string
): ByokProfileMeta | null {
const match =
/^byok-([a-f0-9]{12})-(openai|anthropic|gemini|fal)-(.+)$/.exec(
providerId
);
if (!match) return null;
if (workspaceId && match[1] !== this.workspaceHash(workspaceId)) {
return null;
}
const keyId = match[3];
return {
provider: match[2] as ByokProvider,
source: keyId.startsWith('local-')
? ByokProviderSource.Local
: ByokProviderSource.Server,
keyId: keyId.startsWith('local-') ? undefined : keyId,
};
}
private toKeyConfig(row: {
id: string;
provider: string;
name: string;
description: string | null;
endpoint: string | null;
sortOrder: number;
enabled: boolean;
disabledReason: string | null;
lastValidatedAt: Date | null;
lastValidationError: string | null;
lastUsedAt: Date | null;
lastErrorAt: Date | null;
lastError: string | null;
}): ByokKeyConfig {
const provider = row.provider as ByokProvider;
return {
id: row.id,
provider,
name: row.name,
description: row.description,
storage: ByokKeyStorage.server,
configured: true,
enabled: row.enabled,
endpoint: row.endpoint,
endpointEditable: this.customEndpointSupported,
sortOrder: row.sortOrder,
capabilities: this.capabilities(provider, 'server'),
testStatus: row.lastValidationError
? ByokKeyTestStatus.failed
: row.lastValidatedAt
? ByokKeyTestStatus.passed
: ByokKeyTestStatus.untested,
disabledReason: row.disabledReason,
lastTestedAt: row.lastValidatedAt,
lastTestError: row.lastValidationError,
lastUsedAt: row.lastUsedAt,
lastErrorAt: row.lastErrorAt,
lastError: row.lastError,
};
}
private capabilities(provider: ByokProvider, storage: 'server' | 'local') {
switch (provider) {
case ByokProvider.openai:
return ['Text', 'Image input', 'Actions', 'Image generate'];
case ByokProvider.anthropic:
return ['Text', 'Image input'];
case ByokProvider.gemini:
return storage === 'server'
? [
'Text',
'Image input',
'Actions',
'Image generate',
'Transcript',
'Indexing',
]
: ['Text', 'Image input', 'Actions', 'Image generate'];
case ByokProvider.fal:
return ['Image generate'];
}
}
private buildWarnings(keys: ByokKeyConfig[]) {
const activeServerGemini = keys.some(
key =>
key.provider === ByokProvider.gemini &&
key.storage === ByokKeyStorage.server &&
key.enabled
);
if (activeServerGemini) {
return [];
}
return [
{
featureKind: 'transcript',
reason:
'Transcript and workspace indexing require a server Gemini BYOK key or AFFiNE AI plan fallback.',
requiredProviders: [ByokProvider.gemini],
},
{
featureKind: 'workspace_indexing',
reason:
'Workspace indexing requires a server Gemini BYOK key or AFFiNE AI plan fallback.',
requiredProviders: [ByokProvider.gemini],
},
];
}
private normalizeEndpoint(endpoint?: string | null) {
if (!endpoint) return null;
if (!this.customEndpointSupported) {
throw new BadRequestException('Custom BYOK endpoint is not supported.');
}
let parsed: URL;
try {
parsed = new URL(endpoint);
} catch {
throw new BadRequestException('Invalid BYOK endpoint.');
}
if (!['https:', 'http:'].includes(parsed.protocol)) {
throw new BadRequestException('BYOK endpoint must use HTTP or HTTPS.');
}
return parsed.toString().replace(/\/$/, '');
}
private assertProvider(provider: ByokProvider) {
if (!BYOK_ALLOWED_PROVIDERS.includes(provider)) {
throw new BadRequestException('Unsupported BYOK provider.');
}
}
private async runProviderProbe(
provider: ByokProvider,
apiKey: string,
endpoint: string | null
) {
const controller = new AbortController();
const timeout = setTimeout(() => controller.abort(), TEST_TIMEOUT_MS);
try {
const request = this.buildProbeRequest(provider, apiKey, endpoint);
const response = await fetch(request.url, {
method: request.method,
headers: request.headers as unknown as Record<string, string>,
signal: controller.signal,
});
if (!response.ok) {
throw new BadRequestException(
this.providerProbeFailureMessage(response.status)
);
}
} finally {
clearTimeout(timeout);
}
}
private buildProbeRequest(
provider: ByokProvider,
apiKey: string,
endpoint: string | null
) {
switch (provider) {
case ByokProvider.openai:
return {
method: 'GET',
url: `${endpoint ?? 'https://api.openai.com/v1'}/models`,
headers: { Authorization: `Bearer ${apiKey}` },
};
case ByokProvider.anthropic:
return {
method: 'GET',
url: `${endpoint ?? 'https://api.anthropic.com/v1'}/models`,
headers: {
'x-api-key': apiKey,
'anthropic-version': '2023-06-01',
},
};
case ByokProvider.gemini:
return {
method: 'GET',
url: `${endpoint ?? 'https://generativelanguage.googleapis.com/v1beta'}/models`,
headers: { 'x-goog-api-key': apiKey },
};
case ByokProvider.fal:
return {
method: 'GET',
url: 'https://api.fal.ai/v1/models?limit=10',
headers: { Authorization: `Key ${apiKey}` },
};
}
}
private sanitizeError(error: unknown) {
if (error instanceof Error && error.name === 'AbortError') {
return 'Provider key test timed out.';
}
if (error instanceof BadRequestException && error.message) {
return error.message.slice(0, 300);
}
return 'Provider request failed.';
}
private providerProbeFailureMessage(status: number) {
switch (status) {
case 401:
return 'Provider rejected the BYOK key.';
case 403:
return 'Provider rejected the BYOK key permissions.';
case 404:
return 'Provider probe endpoint was not found.';
case 429:
return 'Provider rate limit exceeded while testing the key.';
default:
return status >= 500
? 'Provider service is unavailable.'
: `Provider key test failed with HTTP ${status}.`;
}
}
private workspaceHash(workspaceId: string) {
return createHash('sha256').update(workspaceId).digest('hex').slice(0, 12);
}
private leaseCacheKey(leaseId: string) {
return `copilot:byok:lease:${leaseId}`;
}
private async getActiveLocalLease(activeCacheKey: string) {
const active = await this.cache.get<LocalLeaseActive>(activeCacheKey);
if (!active) return null;
if (await this.cache.has(this.leaseCacheKey(active.leaseId))) {
return { leaseId: active.leaseId, expiresAt: new Date(active.expiresAt) };
}
await this.cache.delete(activeCacheKey);
return null;
}
private localLeaseActiveCacheKey(input: {
workspaceId: string;
userId: string;
providers: ByokLocalLeaseProvider[];
}) {
const fingerprint = createHmac(
'sha256',
this.crypto.keyPair.sha256.privateKey
)
.update(
JSON.stringify(
input.providers.map(provider => ({
provider: provider.provider,
name: provider.name,
description: provider.description ?? null,
apiKey: provider.apiKey,
endpoint: provider.endpoint ?? null,
sortOrder: provider.sortOrder ?? 0,
enabled: provider.enabled ?? true,
}))
)
)
.digest('hex');
return `copilot:byok:lease:active:${input.workspaceId}:${input.userId}:${fingerprint}`;
}
}

View File

@@ -0,0 +1,79 @@
import { registerEnumType } from '@nestjs/graphql';
import { CopilotProviderType } from '../providers/types';
export enum ByokProvider {
openai = 'openai',
anthropic = 'anthropic',
gemini = 'gemini',
fal = 'fal',
}
export enum ByokKeyStorage {
server = 'server',
local = 'local',
}
export enum ByokKeyTestStatus {
untested = 'untested',
passed = 'passed',
failed = 'failed',
}
export enum ByokProviderSource {
Server = 'byok_server',
Local = 'byok_local',
AffinePlan = 'affine_plan',
}
export type ByokFeatureKind =
| 'chat'
| 'action'
| 'image'
| 'embedding'
| 'rerank'
| 'transcript'
| 'workspace_indexing';
export const BYOK_ALLOWED_PROVIDERS = [
ByokProvider.openai,
ByokProvider.anthropic,
ByokProvider.gemini,
ByokProvider.fal,
] as const;
export function byokProviderToCopilotType(provider: ByokProvider) {
switch (provider) {
case ByokProvider.openai:
return CopilotProviderType.OpenAI;
case ByokProvider.anthropic:
return CopilotProviderType.Anthropic;
case ByokProvider.gemini:
return CopilotProviderType.Gemini;
case ByokProvider.fal:
return CopilotProviderType.FAL;
}
}
export function copilotTypeToByokProvider(type: CopilotProviderType) {
switch (type) {
case CopilotProviderType.OpenAI:
return ByokProvider.openai;
case CopilotProviderType.Anthropic:
return ByokProvider.anthropic;
case CopilotProviderType.Gemini:
return ByokProvider.gemini;
case CopilotProviderType.FAL:
return ByokProvider.fal;
default:
return null;
}
}
export function isByokProvider(value: string): value is ByokProvider {
return (BYOK_ALLOWED_PROVIDERS as readonly string[]).includes(value);
}
registerEnumType(ByokProvider, { name: 'ByokProvider' });
registerEnumType(ByokKeyStorage, { name: 'ByokKeyStorage' });
registerEnumType(ByokKeyTestStatus, { name: 'ByokKeyTestStatus' });

View File

@@ -12,9 +12,7 @@ import {
import { CloudflareWorkersAIConfig } from './providers/cloudflare'; import { CloudflareWorkersAIConfig } from './providers/cloudflare';
import type { FalConfig } from './providers/fal'; import type { FalConfig } from './providers/fal';
import { GeminiGenerativeConfig, GeminiVertexConfig } from './providers/gemini'; import { GeminiGenerativeConfig, GeminiVertexConfig } from './providers/gemini';
import { MorphConfig } from './providers/morph';
import { OpenAIConfig } from './providers/openai'; import { OpenAIConfig } from './providers/openai';
import { PerplexityConfig } from './providers/perplexity';
import { import {
CopilotProviderType, CopilotProviderType,
ModelOutputType, ModelOutputType,
@@ -27,10 +25,8 @@ export type CopilotProviderConfigMap = {
[CopilotProviderType.FAL]: FalConfig; [CopilotProviderType.FAL]: FalConfig;
[CopilotProviderType.Gemini]: GeminiGenerativeConfig; [CopilotProviderType.Gemini]: GeminiGenerativeConfig;
[CopilotProviderType.GeminiVertex]: GeminiVertexConfig; [CopilotProviderType.GeminiVertex]: GeminiVertexConfig;
[CopilotProviderType.Perplexity]: PerplexityConfig;
[CopilotProviderType.Anthropic]: AnthropicOfficialConfig; [CopilotProviderType.Anthropic]: AnthropicOfficialConfig;
[CopilotProviderType.AnthropicVertex]: AnthropicVertexConfig; [CopilotProviderType.AnthropicVertex]: AnthropicVertexConfig;
[CopilotProviderType.Morph]: MorphConfig;
}; };
export type ProviderSpecificConfig = export type ProviderSpecificConfig =
@@ -138,20 +134,11 @@ const VertexProviderConfigShape = z.object({
fetch: z.any().optional(), fetch: z.any().optional(),
}); });
const PerplexityConfigShape = z.object({
apiKey: z.string(),
endpoint: z.string().optional(),
});
const AnthropicOfficialConfigShape = z.object({ const AnthropicOfficialConfigShape = z.object({
apiKey: z.string(), apiKey: z.string(),
baseURL: z.string().optional(), baseURL: z.string().optional(),
}); });
const MorphConfigShape = z.object({
apiKey: z.string().optional(),
});
const CopilotProviderProfileShape = z.discriminatedUnion('type', [ const CopilotProviderProfileShape = z.discriminatedUnion('type', [
CopilotProviderProfileBaseShape.extend({ CopilotProviderProfileBaseShape.extend({
type: z.literal(CopilotProviderType.OpenAI), type: z.literal(CopilotProviderType.OpenAI),
@@ -173,10 +160,6 @@ const CopilotProviderProfileShape = z.discriminatedUnion('type', [
type: z.literal(CopilotProviderType.GeminiVertex), type: z.literal(CopilotProviderType.GeminiVertex),
config: VertexProviderConfigShape, config: VertexProviderConfigShape,
}), }),
CopilotProviderProfileBaseShape.extend({
type: z.literal(CopilotProviderType.Perplexity),
config: PerplexityConfigShape,
}),
CopilotProviderProfileBaseShape.extend({ CopilotProviderProfileBaseShape.extend({
type: z.literal(CopilotProviderType.Anthropic), type: z.literal(CopilotProviderType.Anthropic),
config: AnthropicOfficialConfigShape, config: AnthropicOfficialConfigShape,
@@ -185,10 +168,6 @@ const CopilotProviderProfileShape = z.discriminatedUnion('type', [
type: z.literal(CopilotProviderType.AnthropicVertex), type: z.literal(CopilotProviderType.AnthropicVertex),
config: VertexProviderConfigShape, config: VertexProviderConfigShape,
}), }),
CopilotProviderProfileBaseShape.extend({
type: z.literal(CopilotProviderType.Morph),
config: MorphConfigShape,
}),
]); ]);
const CopilotProviderDefaultsShape = z.object({ const CopilotProviderDefaultsShape = z.object({
@@ -205,6 +184,13 @@ declare global {
interface AppConfigSchema { interface AppConfigSchema {
copilot: { copilot: {
enabled: boolean; enabled: boolean;
byok: {
enabled: ConfigItem<boolean>;
allowedProviders: ConfigItem<
Array<'openai' | 'anthropic' | 'gemini' | 'fal'>
>;
allowCustomEndpoint: ConfigItem<boolean>;
};
unsplash: ConfigItem<{ unsplash: ConfigItem<{
key: string; key: string;
}>; }>;
@@ -220,10 +206,8 @@ declare global {
fal: ConfigItem<FalConfig>; fal: ConfigItem<FalConfig>;
gemini: ConfigItem<GeminiGenerativeConfig>; gemini: ConfigItem<GeminiGenerativeConfig>;
geminiVertex: ConfigItem<GeminiVertexConfig>; geminiVertex: ConfigItem<GeminiVertexConfig>;
perplexity: ConfigItem<PerplexityConfig>;
anthropic: ConfigItem<AnthropicOfficialConfig>; anthropic: ConfigItem<AnthropicOfficialConfig>;
anthropicVertex: ConfigItem<AnthropicVertexConfig>; anthropicVertex: ConfigItem<AnthropicVertexConfig>;
morph: ConfigItem<MorphConfig>;
}; };
}; };
} }
@@ -234,6 +218,21 @@ defineModuleConfig('copilot', {
desc: 'Whether to enable the copilot plugin. <br> Document: <a href="https://docs.affine.pro/self-host-affine/administer/ai" target="_blank">https://docs.affine.pro/self-host-affine/administer/ai</a>', desc: 'Whether to enable the copilot plugin. <br> Document: <a href="https://docs.affine.pro/self-host-affine/administer/ai" target="_blank">https://docs.affine.pro/self-host-affine/administer/ai</a>',
default: false, default: false,
}, },
'byok.enabled': {
desc: 'Whether to enable workspace BYOK.',
default: true,
shape: z.boolean(),
},
'byok.allowedProviders': {
desc: 'The allowlist for workspace BYOK providers.',
default: ['openai', 'anthropic', 'gemini', 'fal'],
shape: z.array(z.enum(['openai', 'anthropic', 'gemini', 'fal'])),
},
'byok.allowCustomEndpoint': {
desc: 'Whether workspace BYOK custom endpoints are accepted.',
default: false,
shape: z.boolean(),
},
'providers.profiles': { 'providers.profiles': {
desc: 'The profile list for copilot providers.', desc: 'The profile list for copilot providers.',
default: [], default: [],
@@ -277,12 +276,6 @@ defineModuleConfig('copilot', {
default: {}, default: {},
schema: VertexSchema, schema: VertexSchema,
}, },
'providers.perplexity': {
desc: 'The config for the perplexity provider.',
default: {
apiKey: '',
},
},
'providers.anthropic': { 'providers.anthropic': {
desc: 'The config for the anthropic provider.', desc: 'The config for the anthropic provider.',
default: { default: {
@@ -295,10 +288,6 @@ defineModuleConfig('copilot', {
default: {}, default: {},
schema: VertexSchema, schema: VertexSchema,
}, },
'providers.morph': {
desc: 'The config for the morph provider.',
default: {},
},
unsplash: { unsplash: {
desc: 'The config for the unsplash key.', desc: 'The config for the unsplash key.',
default: { default: {

View File

@@ -15,7 +15,11 @@ import {
Models, Models,
} from '../../../models'; } from '../../../models';
import { CopilotEmbeddingClientService } from '../embedding/client'; import { CopilotEmbeddingClientService } from '../embedding/client';
import type { EmbeddingClient } from '../embedding/types'; import type {
EmbeddingCallOptions,
EmbeddingClient,
EmbeddingRouteContext,
} from '../embedding/types';
import { ContextSession } from './session'; import { ContextSession } from './session';
const CONTEXT_SESSION_KEY = 'context-session'; const CONTEXT_SESSION_KEY = 'context-session';
@@ -62,6 +66,14 @@ export class CopilotContextService implements OnApplicationBootstrap {
return this.client ?? this.embeddingClients.getClient(); return this.client ?? this.embeddingClients.getClient();
} }
private embeddingOptions(
workspaceId: string,
signal?: AbortSignal,
routeContext: EmbeddingRouteContext = {}
): EmbeddingCallOptions {
return { workspaceId, signal, ...routeContext, featureKind: 'embedding' };
}
private async saveConfig( private async saveConfig(
contextId: string, contextId: string,
config: ContextConfig, config: ContextConfig,
@@ -172,11 +184,13 @@ export class CopilotContextService implements OnApplicationBootstrap {
content: string, content: string,
topK: number = 5, topK: number = 5,
signal?: AbortSignal, signal?: AbortSignal,
threshold: number = 0.5 threshold: number = 0.5,
routeContext?: EmbeddingRouteContext
) { ) {
const client = this.embeddingClient; const client = this.embeddingClient;
if (!client) return []; if (!client) return [];
const embedding = await client.getEmbedding(content, signal); const options = this.embeddingOptions(workspaceId, signal, routeContext);
const embedding = await client.getEmbedding(content, options);
if (!embedding) return []; if (!embedding) return [];
const blobChunks = await this.models.copilotWorkspace.matchBlobEmbedding( const blobChunks = await this.models.copilotWorkspace.matchBlobEmbedding(
@@ -187,7 +201,7 @@ export class CopilotContextService implements OnApplicationBootstrap {
); );
if (!blobChunks.length) return []; if (!blobChunks.length) return [];
return await client.reRank(content, blobChunks, topK, signal); return await client.reRank(content, blobChunks, topK, options);
} }
async matchWorkspaceFiles( async matchWorkspaceFiles(
@@ -195,11 +209,13 @@ export class CopilotContextService implements OnApplicationBootstrap {
content: string, content: string,
topK: number = 5, topK: number = 5,
signal?: AbortSignal, signal?: AbortSignal,
threshold: number = 0.5 threshold: number = 0.5,
routeContext?: EmbeddingRouteContext
) { ) {
const client = this.embeddingClient; const client = this.embeddingClient;
if (!client) return []; if (!client) return [];
const embedding = await client.getEmbedding(content, signal); const options = this.embeddingOptions(workspaceId, signal, routeContext);
const embedding = await client.getEmbedding(content, options);
if (!embedding) return []; if (!embedding) return [];
const fileChunks = await this.models.copilotWorkspace.matchFileEmbedding( const fileChunks = await this.models.copilotWorkspace.matchFileEmbedding(
@@ -210,7 +226,7 @@ export class CopilotContextService implements OnApplicationBootstrap {
); );
if (!fileChunks.length) return []; if (!fileChunks.length) return [];
return await client.reRank(content, fileChunks, topK, signal); return await client.reRank(content, fileChunks, topK, options);
} }
async matchWorkspaceDocs( async matchWorkspaceDocs(
@@ -218,11 +234,13 @@ export class CopilotContextService implements OnApplicationBootstrap {
content: string, content: string,
topK: number = 5, topK: number = 5,
signal?: AbortSignal, signal?: AbortSignal,
threshold: number = 0.5 threshold: number = 0.5,
routeContext?: EmbeddingRouteContext
) { ) {
const client = this.embeddingClient; const client = this.embeddingClient;
if (!client) return []; if (!client) return [];
const embedding = await client.getEmbedding(content, signal); const options = this.embeddingOptions(workspaceId, signal, routeContext);
const embedding = await client.getEmbedding(content, options);
if (!embedding) return []; if (!embedding) return [];
const workspaceChunks = const workspaceChunks =
@@ -234,7 +252,7 @@ export class CopilotContextService implements OnApplicationBootstrap {
); );
if (!workspaceChunks.length) return []; if (!workspaceChunks.length) return [];
return await client.reRank(content, workspaceChunks, topK, signal); return await client.reRank(content, workspaceChunks, topK, options);
} }
async matchWorkspaceAll( async matchWorkspaceAll(
@@ -244,11 +262,13 @@ export class CopilotContextService implements OnApplicationBootstrap {
signal?: AbortSignal, signal?: AbortSignal,
threshold: number = 0.8, threshold: number = 0.8,
docIds?: string[], docIds?: string[],
scopedThreshold: number = 0.85 scopedThreshold: number = 0.85,
routeContext?: EmbeddingRouteContext
) { ) {
const client = this.embeddingClient; const client = this.embeddingClient;
if (!client) return []; if (!client) return [];
const embedding = await client.getEmbedding(content, signal); const options = this.embeddingOptions(workspaceId, signal, routeContext);
const embedding = await client.getEmbedding(content, options);
if (!embedding) return []; if (!embedding) return [];
const [fileChunks, blobChunks, workspaceChunks, scopedWorkspaceChunks] = const [fileChunks, blobChunks, workspaceChunks, scopedWorkspaceChunks] =
@@ -300,7 +320,7 @@ export class CopilotContextService implements OnApplicationBootstrap {
...(scopedWorkspaceChunks || []), ...(scopedWorkspaceChunks || []),
], ],
topK, topK,
signal options
); );
} }

View File

@@ -11,7 +11,11 @@ import {
FileChunkSimilarity, FileChunkSimilarity,
Models, Models,
} from '../../../models'; } from '../../../models';
import { EmbeddingClient } from '../embedding/types'; import type {
EmbeddingCallOptions,
EmbeddingClient,
EmbeddingRouteContext,
} from '../embedding/types';
export class ContextSession implements AsyncDisposable { export class ContextSession implements AsyncDisposable {
constructor( constructor(
@@ -69,6 +73,18 @@ export class ContextSession implements AsyncDisposable {
); );
} }
private embeddingOptions(
signal?: AbortSignal,
routeContext: EmbeddingRouteContext = {}
): EmbeddingCallOptions {
return {
workspaceId: this.workspaceId,
signal,
...routeContext,
featureKind: 'embedding',
};
}
async addCategoryRecord(type: ContextCategories, id: string, docs: string[]) { async addCategoryRecord(type: ContextCategories, id: string, docs: string[]) {
const category = this.config.categories.find( const category = this.config.categories.find(
c => c.type === type && c.id === id c => c.type === type && c.id === id
@@ -269,10 +285,12 @@ export class ContextSession implements AsyncDisposable {
topK: number = 5, topK: number = 5,
signal?: AbortSignal, signal?: AbortSignal,
scopedThreshold: number = 0.85, scopedThreshold: number = 0.85,
threshold: number = 0.5 threshold: number = 0.5,
routeContext?: EmbeddingRouteContext
): Promise<FileChunkSimilarity[]> { ): Promise<FileChunkSimilarity[]> {
if (!this.client) return []; if (!this.client) return [];
const embedding = await this.client.getEmbedding(content, signal); const options = this.embeddingOptions(signal, routeContext);
const embedding = await this.client.getEmbedding(content, options);
if (!embedding) return []; if (!embedding) return [];
const [context, workspace] = await Promise.all([ const [context, workspace] = await Promise.all([
@@ -305,7 +323,7 @@ export class ContextSession implements AsyncDisposable {
...workspace, ...workspace,
], ],
topK, topK,
signal options
); );
} }
@@ -322,10 +340,12 @@ export class ContextSession implements AsyncDisposable {
topK: number = 5, topK: number = 5,
signal?: AbortSignal, signal?: AbortSignal,
scopedThreshold: number = 0.85, scopedThreshold: number = 0.85,
threshold: number = 0.5 threshold: number = 0.5,
routeContext?: EmbeddingRouteContext
) { ) {
if (!this.client) return []; if (!this.client) return [];
const embedding = await this.client.getEmbedding(content, signal); const options = this.embeddingOptions(signal, routeContext);
const embedding = await this.client.getEmbedding(content, options);
if (!embedding) return []; if (!embedding) return [];
const docIds = this.docIds; const docIds = this.docIds;
@@ -349,7 +369,7 @@ export class ContextSession implements AsyncDisposable {
content, content,
[...inContext, ...workspace], [...inContext, ...workspace],
topK, topK,
signal options
); );
// sort result, doc recorded in context first // sort result, doc recorded in context first

View File

@@ -1,7 +1,7 @@
import { Injectable } from '@nestjs/common'; import { Injectable } from '@nestjs/common';
import { CopilotQuotaExceeded } from '../../../base'; import { CopilotQuotaExceeded } from '../../../base';
import { QuotaService } from '../../../core/quota'; import { QuotaService } from '../../../core/quota/service';
import { Models } from '../../../models'; import { Models } from '../../../models';
import type { Turn } from '../core'; import type { Turn } from '../core';
import type { ResolvedPrompt } from '../prompt'; import type { ResolvedPrompt } from '../prompt';
@@ -31,12 +31,16 @@ export class ConversationPolicy {
} }
async checkQuota(userId: string) { async checkQuota(userId: string) {
const { limit, used } = await this.getQuota(userId); if (!(await this.hasQuota(userId))) {
if (limit && Number.isFinite(limit) && used >= limit) {
throw new CopilotQuotaExceeded(); throw new CopilotQuotaExceeded();
} }
} }
async hasQuota(userId: string) {
const { limit, used } = await this.getQuota(userId);
return !(limit !== undefined && Number.isFinite(limit) && used >= limit);
}
shouldScheduleTitle(prompt: Pick<ResolvedPrompt, 'action'>) { shouldScheduleTitle(prompt: Pick<ResolvedPrompt, 'action'>) {
return !prompt.action; return !prompt.action;
} }

View File

@@ -11,7 +11,12 @@ import {
import { type CopilotRerankRequest } from '../providers/types'; import { type CopilotRerankRequest } from '../providers/types';
import { CapabilityRuntime } from '../runtime/capability-runtime'; import { CapabilityRuntime } from '../runtime/capability-runtime';
import { TaskPolicy } from '../runtime/task-policy'; import { TaskPolicy } from '../runtime/task-policy';
import { EmbeddingClient, type ReRankResult } from './types'; import {
type EmbeddingCallOptionsInput,
EmbeddingClient,
normalizeEmbeddingCallOptions,
type ReRankResult,
} from './types';
class ProductionEmbeddingClient extends EmbeddingClient { class ProductionEmbeddingClient extends EmbeddingClient {
private readonly logger = new Logger(ProductionEmbeddingClient.name); private readonly logger = new Logger(ProductionEmbeddingClient.name);
@@ -35,10 +40,19 @@ class ProductionEmbeddingClient extends EmbeddingClient {
return result; return result;
} }
async getEmbeddings(input: string[]): Promise<Embedding[]> { async getEmbeddings(
input: string[],
options?: EmbeddingCallOptionsInput
): Promise<Embedding[]> {
const normalizedOptions = normalizeEmbeddingCallOptions(options);
const modelId = this.taskPolicy.resolveEmbeddingModelId(); const modelId = this.taskPolicy.resolveEmbeddingModelId();
const embeddings = await this.runtime.embed(modelId, input, { const embeddings = await this.runtime.embed(modelId, input, {
dimensions: EMBEDDING_DIMENSIONS, dimensions: EMBEDDING_DIMENSIONS,
signal: normalizedOptions.signal,
user: normalizedOptions.userId,
workspace: normalizedOptions.workspaceId,
byokLeaseId: normalizedOptions.byokLeaseId,
featureKind: normalizedOptions.featureKind ?? 'embedding',
}); });
if (embeddings.length !== input.length) { if (embeddings.length !== input.length) {
throw new CopilotFailedToGenerateEmbedding({ throw new CopilotFailedToGenerateEmbedding({
@@ -67,8 +81,9 @@ class ProductionEmbeddingClient extends EmbeddingClient {
>( >(
query: string, query: string,
embeddings: Chunk[], embeddings: Chunk[],
signal?: AbortSignal options?: EmbeddingCallOptionsInput
): Promise<ReRankResult> { ): Promise<ReRankResult> {
const normalizedOptions = normalizeEmbeddingCallOptions(options);
if (!embeddings.length) return []; if (!embeddings.length) return [];
const rerankRequest: CopilotRerankRequest = { const rerankRequest: CopilotRerankRequest = {
@@ -82,7 +97,13 @@ class ProductionEmbeddingClient extends EmbeddingClient {
const ranks = await this.runtime.rerank( const ranks = await this.runtime.rerank(
this.taskPolicy.resolveRerankModelId(), this.taskPolicy.resolveRerankModelId(),
rerankRequest, rerankRequest,
{ signal } {
signal: normalizedOptions.signal,
user: normalizedOptions.userId,
workspace: normalizedOptions.workspaceId,
byokLeaseId: normalizedOptions.byokLeaseId,
featureKind: 'rerank',
}
); );
try { try {
@@ -105,8 +126,9 @@ class ProductionEmbeddingClient extends EmbeddingClient {
query: string, query: string,
embeddings: Chunk[], embeddings: Chunk[],
topK: number, topK: number,
signal?: AbortSignal options?: EmbeddingCallOptionsInput
): Promise<Chunk[]> { ): Promise<Chunk[]> {
const normalizedOptions = normalizeEmbeddingCallOptions(options);
// search in context and workspace may find same chunks, de-duplicate them // search in context and workspace may find same chunks, de-duplicate them
const { deduped: dedupedEmbeddings } = embeddings.reduce( const { deduped: dedupedEmbeddings } = embeddings.reduce(
(acc, e) => { (acc, e) => {
@@ -138,14 +160,19 @@ class ProductionEmbeddingClient extends EmbeddingClient {
const ranks = await this.getEmbeddingRelevance( const ranks = await this.getEmbeddingRelevance(
query, query,
sortedEmbeddings, sortedEmbeddings,
signal normalizedOptions
); );
if (sortedEmbeddings.length !== ranks.length) { if (sortedEmbeddings.length !== ranks.length) {
// llm return wrong result, fallback to default sorting // llm return wrong result, fallback to default sorting
this.logger.warn( this.logger.warn(
`Batch size mismatch: expected ${sortedEmbeddings.length}, got ${ranks.length}` `Batch size mismatch: expected ${sortedEmbeddings.length}, got ${ranks.length}`
); );
return await super.reRank(query, dedupedEmbeddings, topK, signal); return await super.reRank(
query,
dedupedEmbeddings,
topK,
normalizedOptions
);
} }
const highConfidenceChunks = ranks const highConfidenceChunks = ranks
@@ -164,7 +191,12 @@ class ProductionEmbeddingClient extends EmbeddingClient {
return highConfidenceChunks.slice(0, topK); return highConfidenceChunks.slice(0, topK);
} catch (error) { } catch (error) {
this.logger.warn('ReRank failed, falling back to default sorting', error); this.logger.warn('ReRank failed, falling back to default sorting', error);
return await super.reRank(query, dedupedEmbeddings, topK, signal); return await super.reRank(
query,
dedupedEmbeddings,
topK,
normalizedOptions
);
} }
} }
} }
@@ -180,7 +212,8 @@ export class CopilotEmbeddingClientService {
async refresh() { async refresh() {
const client = new ProductionEmbeddingClient(this.taskPolicy, this.runtime); const client = new ProductionEmbeddingClient(this.taskPolicy, this.runtime);
this.client = (await client.configured()) ? client : undefined; await client.configured();
this.client = client;
return this.client; return this.client;
} }

View File

@@ -18,7 +18,7 @@ import { Models } from '../../../models';
import { CopilotStorage } from '../storage'; import { CopilotStorage } from '../storage';
import { readStream } from '../utils'; import { readStream } from '../utils';
import { CopilotEmbeddingClientService } from './client'; import { CopilotEmbeddingClientService } from './client';
import type { Chunk, DocFragment } from './types'; import type { Chunk, DocFragment, EmbeddingCallOptions } from './types';
import { EmbeddingClient } from './types'; import { EmbeddingClient } from './types';
@Injectable() @Injectable()
@@ -242,6 +242,19 @@ export class CopilotEmbeddingJob {
return new File([buffer], fileName); return new File([buffer], fileName);
} }
private workspaceIndexingOptions(
workspaceId: string,
signal?: AbortSignal,
userId?: string
): EmbeddingCallOptions {
return {
workspaceId,
userId,
signal,
featureKind: 'workspace_indexing',
};
}
@OnJob('copilot.embedding.files') @OnJob('copilot.embedding.files')
async embedPendingFile({ async embedPendingFile({
userId, userId,
@@ -266,7 +279,10 @@ export class CopilotEmbeddingJob {
const total = chunks.reduce((acc, c) => acc + c.length, 0); const total = chunks.reduce((acc, c) => acc + c.length, 0);
for (const chunk of chunks) { for (const chunk of chunks) {
const embeddings = await this.embeddingClient.generateEmbeddings(chunk); const embeddings = await this.embeddingClient.generateEmbeddings(
chunk,
this.workspaceIndexingOptions(workspaceId, undefined, userId)
);
if (contextId) { if (contextId) {
// for context files // for context files
await this.models.copilotContext.insertFileEmbedding( await this.models.copilotContext.insertFileEmbedding(
@@ -320,7 +336,10 @@ export class CopilotEmbeddingJob {
const total = chunks.reduce((acc, c) => acc + c.length, 0); const total = chunks.reduce((acc, c) => acc + c.length, 0);
for (const chunk of chunks) { for (const chunk of chunks) {
const embeddings = await this.embeddingClient.generateEmbeddings(chunk); const embeddings = await this.embeddingClient.generateEmbeddings(
chunk,
this.workspaceIndexingOptions(workspaceId)
);
await this.models.copilotWorkspace.insertBlobEmbeddings( await this.models.copilotWorkspace.insertBlobEmbeddings(
workspaceId, workspaceId,
blobId, blobId,
@@ -462,7 +481,7 @@ export class CopilotEmbeddingJob {
`${fragment.title || 'Untitled'}.md` `${fragment.title || 'Untitled'}.md`
), ),
chunks => this.formatDocChunks(chunks, fragment), chunks => this.formatDocChunks(chunks, fragment),
signal this.workspaceIndexingOptions(workspaceId, signal)
); );
for (const chunks of embeddings) { for (const chunks of embeddings) {

View File

@@ -6,6 +6,7 @@ import { CopilotContextFileNotSupported } from '../../../base';
import type { PageDocContent } from '../../../core/utils/blocksuite'; import type { PageDocContent } from '../../../core/utils/blocksuite';
import { ChunkSimilarity, Embedding } from '../../../models'; import { ChunkSimilarity, Embedding } from '../../../models';
import { parseDoc } from '../../../native'; import { parseDoc } from '../../../native';
import type { ByokFeatureKind } from '../byok/types';
declare global { declare global {
interface Events { interface Events {
@@ -103,6 +104,35 @@ export type Chunk = {
content: string; content: string;
}; };
export type EmbeddingCallOptions = {
signal?: AbortSignal;
userId?: string;
workspaceId?: string;
byokLeaseId?: string;
featureKind?: Extract<
ByokFeatureKind,
'embedding' | 'workspace_indexing' | 'rerank'
>;
};
export type EmbeddingCallOptionsInput = AbortSignal | EmbeddingCallOptions;
export type EmbeddingRouteContext = Pick<
EmbeddingCallOptions,
'userId' | 'byokLeaseId'
>;
export function normalizeEmbeddingCallOptions(
options?: EmbeddingCallOptionsInput
): EmbeddingCallOptions {
if (!options) {
return {};
}
if ('aborted' in options && 'addEventListener' in options) {
return { signal: options };
}
return options;
}
export abstract class EmbeddingClient { export abstract class EmbeddingClient {
async configured() { async configured() {
return true; return true;
@@ -111,11 +141,14 @@ export abstract class EmbeddingClient {
async getFileEmbeddings( async getFileEmbeddings(
file: File, file: File,
chunkMapper: (chunk: Chunk[]) => Chunk[], chunkMapper: (chunk: Chunk[]) => Chunk[],
signal?: AbortSignal options?: EmbeddingCallOptionsInput
): Promise<Embedding[][]> { ): Promise<Embedding[][]> {
const chunks = await this.getFileChunks(file, signal); const normalizedOptions = normalizeEmbeddingCallOptions(options);
const chunks = await this.getFileChunks(file, normalizedOptions.signal);
const chunkedEmbeddings = await Promise.all( const chunkedEmbeddings = await Promise.all(
chunks.map(chunk => this.generateEmbeddings(chunkMapper(chunk))) chunks.map(chunk =>
this.generateEmbeddings(chunkMapper(chunk), normalizedOptions)
)
); );
return chunkedEmbeddings; return chunkedEmbeddings;
} }
@@ -154,8 +187,9 @@ export abstract class EmbeddingClient {
async generateEmbeddings( async generateEmbeddings(
chunks: Chunk[], chunks: Chunk[],
signal?: AbortSignal options?: EmbeddingCallOptionsInput
): Promise<Embedding[]> { ): Promise<Embedding[]> {
const normalizedOptions = normalizeEmbeddingCallOptions(options);
const retry = 3; const retry = 3;
let embeddings: Embedding[] = []; let embeddings: Embedding[] = [];
@@ -164,7 +198,7 @@ export abstract class EmbeddingClient {
try { try {
embeddings = await this.getEmbeddings( embeddings = await this.getEmbeddings(
chunks.map(c => c.content), chunks.map(c => c.content),
signal normalizedOptions
); );
break; break;
} catch (e) { } catch (e) {
@@ -181,7 +215,7 @@ export abstract class EmbeddingClient {
_query: string, _query: string,
embeddings: Chunk[], embeddings: Chunk[],
topK: number, topK: number,
_signal?: AbortSignal _options?: EmbeddingCallOptionsInput
): Promise<Chunk[]> { ): Promise<Chunk[]> {
// sort by distance with ascending order // sort by distance with ascending order
return embeddings return embeddings
@@ -189,14 +223,14 @@ export abstract class EmbeddingClient {
.slice(0, topK); .slice(0, topK);
} }
async getEmbedding(query: string, signal?: AbortSignal) { async getEmbedding(query: string, options?: EmbeddingCallOptionsInput) {
const embedding = await this.getEmbeddings([query], signal); const embedding = await this.getEmbeddings([query], options);
return embedding?.[0]?.embedding; return embedding?.[0]?.embedding;
} }
abstract getEmbeddings( abstract getEmbeddings(
input: string[], input: string[],
signal?: AbortSignal options?: EmbeddingCallOptionsInput
): Promise<Embedding[]>; ): Promise<Embedding[]>;
} }

View File

@@ -1,3 +1,9 @@
import { CopilotAccessPolicy } from './access';
import {
ByokEntitlementPolicy,
ByokService,
WorkspaceByokResolver,
} from './byok';
import { HistoryAttachmentUrlProjector } from './compat/history-attachment-url-projector'; import { HistoryAttachmentUrlProjector } from './compat/history-attachment-url-projector';
import { CompatHistoryProjector } from './compat/history-projector'; import { CompatHistoryProjector } from './compat/history-projector';
import { HistoryPromptPreloadProjector } from './compat/history-prompt-preload-projector'; import { HistoryPromptPreloadProjector } from './compat/history-prompt-preload-projector';
@@ -64,10 +70,13 @@ export const COPILOT_PROVIDER_PROVIDERS = [
]; ];
export const COPILOT_RUNTIME_PROVIDERS = [ export const COPILOT_RUNTIME_PROVIDERS = [
ByokEntitlementPolicy,
ByokService,
ChatSessionService, ChatSessionService,
ConversationStore, ConversationStore,
ConversationInboxService, ConversationInboxService,
ConversationPolicy, ConversationPolicy,
CopilotAccessPolicy,
HistoryAttachmentUrlProjector, HistoryAttachmentUrlProjector,
CompatHistoryProjector, CompatHistoryProjector,
HistoryPromptPreloadProjector, HistoryPromptPreloadProjector,
@@ -114,6 +123,7 @@ export const COPILOT_RESOLVER_PROVIDERS = [
CopilotResolver, CopilotResolver,
UserCopilotResolver, UserCopilotResolver,
CopilotContextRootResolver, CopilotContextRootResolver,
WorkspaceByokResolver,
]; ];
export const COPILOT_JOB_PROVIDERS = [CopilotEmbeddingJob, CopilotCronJobs]; export const COPILOT_JOB_PROVIDERS = [CopilotEmbeddingJob, CopilotCronJobs];

View File

@@ -1,12 +1,11 @@
import { CopilotProviderSideError, UserFriendlyError } from '../../../base'; import { CopilotProviderSideError, UserFriendlyError } from '../../../base';
import { type LlmBackendConfig } from '../../../native'; import { type LlmBackendConfig } from '../../../native';
import type { CopilotTool } from '../tools';
import { CopilotProvider } from './provider'; import { CopilotProvider } from './provider';
import { import {
type CopilotProviderExecution, type CopilotProviderExecution,
type ProviderDriverSpec, type ProviderDriverSpec,
} from './provider-runtime-contract'; } from './provider-runtime-contract';
import { type CopilotChatTools, CopilotProviderType } from './types'; import { CopilotProviderType } from './types';
export type CloudflareWorkersAIConfig = { export type CloudflareWorkersAIConfig = {
apiToken: string; apiToken: string;
@@ -25,16 +24,6 @@ export class CloudflareWorkersAIProvider extends CopilotProvider<CloudflareWorke
const config = this.getConfig(execution); const config = this.getConfig(execution);
return !!config.apiToken && (!!config.accountId || !!config.baseURL); return !!config.apiToken && (!!config.accountId || !!config.baseURL);
} }
override getProviderSpecificTools(
toolName: CopilotChatTools,
_model: string
): [string, CopilotTool?] | undefined {
if (toolName === 'docEdit') {
return ['doc_edit', undefined];
}
return;
}
private handleError(e: any) { private handleError(e: any) {
if (e instanceof UserFriendlyError) { if (e instanceof UserFriendlyError) {
return e; return e;

View File

@@ -1,10 +1,14 @@
import { Injectable, Logger } from '@nestjs/common'; import { Injectable, Logger } from '@nestjs/common';
import { CopilotQuotaExceeded } from '../../../base';
import { ServerFeature, ServerService } from '../../../core'; import { ServerFeature, ServerService } from '../../../core';
import { type CopilotAccessContext, CopilotAccessPolicy } from '../access';
import type { RequiredStructuredOutputContract } from '../runtime/contracts'; import type { RequiredStructuredOutputContract } from '../runtime/contracts';
import { getProviderRuntimeHost } from '../runtime/provider-runtime-context'; import { getProviderRuntimeHost } from '../runtime/provider-runtime-context';
import type { CopilotProvider } from './provider'; import type { CopilotProvider } from './provider';
import { import {
buildProviderRegistry,
type CopilotProviderRegistry,
type NormalizedCopilotProviderProfile, type NormalizedCopilotProviderProfile,
resolveModel, resolveModel,
stripProviderPrefix, stripProviderPrefix,
@@ -57,11 +61,18 @@ type RoutePreparationResult = Partial<
> >
>; >;
type EffectiveProviderRegistry = {
byokRegistry: CopilotProviderRegistry;
quotaBackedRegistry: CopilotProviderRegistry;
quotaBackedRoutesAvailable: boolean;
};
@Injectable() @Injectable()
export class CopilotProviderFactory { export class CopilotProviderFactory {
constructor( constructor(
private readonly server: ServerService, private readonly server: ServerService,
private readonly registries: CopilotProviderRegistryService private readonly registries: CopilotProviderRegistryService,
private readonly access: CopilotAccessPolicy
) {} ) {}
private readonly logger = new Logger(CopilotProviderFactory.name); private readonly logger = new Logger(CopilotProviderFactory.name);
@@ -73,20 +84,84 @@ export class CopilotProviderFactory {
return this.registries.getRegistry(); return this.registries.getRegistry();
} }
private getPreferredProviderIds(type?: CopilotProviderType) { private getProviderByProfile(
providerId: string,
profile: NormalizedCopilotProviderProfile
) {
return (
this.#providers.get(providerId) ??
Array.from(this.#providerIdsByType.get(profile.type) ?? [])
.map(id => this.#providers.get(id))
.find((provider): provider is CopilotProvider => !!provider)
);
}
private providerAvailable(
providerId: string,
profile: NormalizedCopilotProviderProfile
) {
return !!this.getProviderByProfile(providerId, profile);
}
private getAvailableProviderIds(registry: CopilotProviderRegistry) {
return Array.from(registry.profiles.entries())
.filter(([providerId, profile]) =>
this.providerAvailable(providerId, profile)
)
.map(([providerId]) => providerId);
}
private getPreferredProviderIds(
registry: CopilotProviderRegistry,
type?: CopilotProviderType
) {
if (!type) return undefined; if (!type) return undefined;
return this.#providerIdsByType.get(type); return registry.byType.get(type)?.filter(providerId => {
const profile = registry.profiles.get(providerId);
return profile ? this.providerAvailable(providerId, profile) : false;
});
} }
private normalizeCond( private normalizeCond(
registry: CopilotProviderRegistry,
providerId: string, providerId: string,
cond: ModelFullConditions cond: ModelFullConditions
): ModelFullConditions { ): ModelFullConditions {
const registry = this.getRegistry();
const modelId = stripProviderPrefix(registry, providerId, cond.modelId); const modelId = stripProviderPrefix(registry, providerId, cond.modelId);
return { ...cond, modelId }; return { ...cond, modelId };
} }
private async getEffectiveRegistry(
context: CopilotAccessContext = {}
): Promise<EffectiveProviderRegistry> {
const quotaBackedRegistry = this.getRegistry();
const routeAccess = await this.access.resolveRouteAccess(context);
return {
byokRegistry: buildProviderRegistry({
profiles: routeAccess.byokProfiles,
defaults: {},
}),
quotaBackedRegistry,
quotaBackedRoutesAvailable: routeAccess.quotaBackedRoutesAvailable,
};
}
private getRequestContext(
options?:
| CopilotChatOptions
| CopilotStructuredOptions
| CopilotImageOptions
): CopilotAccessContext {
return {
userId: options?.user,
workspaceId: options?.workspace,
byokLeaseId: options?.byokLeaseId,
featureKind: options?.featureKind,
quotaBackedRoutesAllowed: options?.quotaBackedRoutesAllowed,
};
}
private filterPreparedRoutes(routes: Array<ResolvedCopilotProvider | null>) { private filterPreparedRoutes(routes: Array<ResolvedCopilotProvider | null>) {
return routes.filter( return routes.filter(
(route): route is ResolvedCopilotProvider => route !== null (route): route is ResolvedCopilotProvider => route !== null
@@ -113,36 +188,89 @@ export class CopilotProviderFactory {
cond: ModelFullConditions, cond: ModelFullConditions,
filter: { filter: {
prefer?: CopilotProviderType; prefer?: CopilotProviderType;
} = {} } = {},
context: CopilotAccessContext = {}
): Promise<ResolvedCopilotProvider | null> { ): Promise<ResolvedCopilotProvider | null> {
return (await this.resolveRoutes(cond, filter))[0] ?? null; return (await this.resolveRoutes(cond, filter, context))[0] ?? null;
} }
async resolveRoutes( async resolveRoutes(
cond: ModelFullConditions, cond: ModelFullConditions,
filter: { filter: {
prefer?: CopilotProviderType; prefer?: CopilotProviderType;
} = {} } = {},
context: CopilotAccessContext = {}
): Promise<ResolvedCopilotProvider[]> { ): Promise<ResolvedCopilotProvider[]> {
this.logger.debug( this.logger.debug(
`Resolving copilot provider for output type: ${cond.outputType}` `Resolving copilot provider for output type: ${cond.outputType}`
); );
const registry = this.getRegistry(); const { byokRegistry, quotaBackedRegistry, quotaBackedRoutesAvailable } =
await this.getEffectiveRegistry(context);
const byokRoutes = await this.resolveRoutesFromRegistry(
byokRegistry,
cond,
filter
);
const resolved = byokRoutes.length
? byokRoutes
: quotaBackedRoutesAvailable
? await this.resolveRoutesFromRegistry(
quotaBackedRegistry,
cond,
filter
)
: [];
for (const route of resolved) {
this.logger.debug(
`Copilot provider candidate found: ${route.provider.type} (${route.providerId})`
);
}
if (
!resolved.length &&
!quotaBackedRoutesAvailable &&
context.quotaBackedRoutesAllowed !== false
) {
const quotaBackedRoutes = await this.resolveRoutesFromRegistry(
quotaBackedRegistry,
cond,
filter
);
if (quotaBackedRoutes.length) {
throw new CopilotQuotaExceeded();
}
}
return resolved;
}
private async resolveRoutesFromRegistry(
registry: CopilotProviderRegistry,
cond: ModelFullConditions,
filter: {
prefer?: CopilotProviderType;
} = {}
): Promise<ResolvedCopilotProvider[]> {
const route = resolveModel({ const route = resolveModel({
registry, registry,
modelId: cond.modelId, modelId: cond.modelId,
outputType: cond.outputType, outputType: cond.outputType,
availableProviderIds: this.#providers.keys(), availableProviderIds: this.getAvailableProviderIds(registry),
preferredProviderIds: this.getPreferredProviderIds(filter.prefer), preferredProviderIds: this.getPreferredProviderIds(
registry,
filter.prefer
),
}); });
const resolved: ResolvedCopilotProvider[] = []; const resolved: ResolvedCopilotProvider[] = [];
for (const providerId of route.candidateProviderIds) { for (const providerId of route.candidateProviderIds) {
const provider = this.#providers.get(providerId);
const profile = registry.profiles.get(providerId); const profile = registry.profiles.get(providerId);
const provider = profile
? this.getProviderByProfile(providerId, profile)
: undefined;
if (!provider || !profile) continue; if (!provider || !profile) continue;
const normalizedCond = this.normalizeCond(providerId, cond); const normalizedCond = this.normalizeCond(registry, providerId, cond);
if ( if (
normalizedCond.modelId && normalizedCond.modelId &&
profile.models?.length && profile.models?.length &&
@@ -155,9 +283,6 @@ export class CopilotProviderFactory {
const matched = await provider.match(normalizedCond, execution); const matched = await provider.match(normalizedCond, execution);
if (!matched) continue; if (!matched) continue;
this.logger.debug(
`Copilot provider candidate found: ${provider.type} (${providerId})`
);
resolved.push({ resolved.push({
providerId, providerId,
provider, provider,
@@ -181,7 +306,11 @@ export class CopilotProviderFactory {
prefer?: CopilotProviderType; prefer?: CopilotProviderType;
} = {} } = {}
): Promise<ResolvedCopilotProvider[]> { ): Promise<ResolvedCopilotProvider[]> {
const routes = await this.resolveRoutes(cond, filter); const routes = await this.resolveRoutes(
cond,
filter,
this.getRequestContext(options)
);
return await this.prepareResolvedRoutes(routes, async route => { return await this.prepareResolvedRoutes(routes, async route => {
const prepared = await getProviderRuntimeHost( const prepared = await getProviderRuntimeHost(
route.provider route.provider
@@ -213,7 +342,11 @@ export class CopilotProviderFactory {
} = {}, } = {},
responseContract?: RequiredStructuredOutputContract responseContract?: RequiredStructuredOutputContract
): Promise<ResolvedCopilotProvider[]> { ): Promise<ResolvedCopilotProvider[]> {
const routes = await this.resolveRoutes(cond, filter); const routes = await this.resolveRoutes(
cond,
filter,
this.getRequestContext(options)
);
return await this.prepareResolvedRoutes(routes, async route => { return await this.prepareResolvedRoutes(routes, async route => {
const preparedStructured = const preparedStructured =
(await getProviderRuntimeHost(route.provider).prepare.structured( (await getProviderRuntimeHost(route.provider).prepare.structured(
@@ -239,10 +372,14 @@ export class CopilotProviderFactory {
input: string | string[], input: string | string[],
options: CopilotEmbeddingOptions = {} options: CopilotEmbeddingOptions = {}
): Promise<ResolvedCopilotProvider[]> { ): Promise<ResolvedCopilotProvider[]> {
const routes = await this.resolveRoutes({ const routes = await this.resolveRoutes(
modelId, { modelId, outputType: ModelOutputType.Embedding },
outputType: ModelOutputType.Embedding, {},
}); {
...this.getRequestContext(options),
featureKind: options?.featureKind ?? 'embedding',
}
);
return await this.prepareResolvedRoutes(routes, async route => { return await this.prepareResolvedRoutes(routes, async route => {
const preparedEmbedding = const preparedEmbedding =
(await getProviderRuntimeHost(route.provider).prepare.embedding( (await getProviderRuntimeHost(route.provider).prepare.embedding(
@@ -267,10 +404,14 @@ export class CopilotProviderFactory {
request: CopilotRerankRequest, request: CopilotRerankRequest,
options: CopilotChatOptions = {} options: CopilotChatOptions = {}
): Promise<ResolvedCopilotProvider[]> { ): Promise<ResolvedCopilotProvider[]> {
const routes = await this.resolveRoutes({ const routes = await this.resolveRoutes(
modelId, {
outputType: ModelOutputType.Rerank, modelId,
}); outputType: ModelOutputType.Rerank,
},
{},
{ ...this.getRequestContext(options), featureKind: 'rerank' }
);
return await this.prepareResolvedRoutes(routes, async route => { return await this.prepareResolvedRoutes(routes, async route => {
const preparedRerank = const preparedRerank =
(await getProviderRuntimeHost(route.provider).prepare.rerank( (await getProviderRuntimeHost(route.provider).prepare.rerank(
@@ -298,7 +439,10 @@ export class CopilotProviderFactory {
prefer?: CopilotProviderType; prefer?: CopilotProviderType;
} = {} } = {}
): Promise<ResolvedCopilotProvider[]> { ): Promise<ResolvedCopilotProvider[]> {
const routes = await this.resolveRoutes(cond, filter); const routes = await this.resolveRoutes(cond, filter, {
...this.getRequestContext(options),
featureKind: options?.featureKind ?? 'image',
});
return await this.prepareResolvedRoutes(routes, async route => { return await this.prepareResolvedRoutes(routes, async route => {
const preparedImage = const preparedImage =
(await getProviderRuntimeHost(route.provider).prepare.image( (await getProviderRuntimeHost(route.provider).prepare.image(

View File

@@ -8,7 +8,6 @@ export { FalProvider } from './fal';
export { GeminiGenerativeProvider, GeminiVertexProvider } from './gemini'; export { GeminiGenerativeProvider, GeminiVertexProvider } from './gemini';
export { CopilotProviderLifecycleService } from './lifecycle-service'; export { CopilotProviderLifecycleService } from './lifecycle-service';
export { OpenAIProvider } from './openai'; export { OpenAIProvider } from './openai';
export { PerplexityProvider } from './perplexity';
export type { CopilotProvider } from './provider'; export type { CopilotProvider } from './provider';
export { CopilotProviders } from './provider-tokens'; export { CopilotProviders } from './provider-tokens';
export { CopilotProviderRegistryService } from './registry-service'; export { CopilotProviderRegistryService } from './registry-service';

View File

@@ -1,60 +0,0 @@
import { CopilotProviderSideError, UserFriendlyError } from '../../../base';
import { type LlmBackendConfig } from '../../../native';
import { CopilotProvider } from './provider';
import {
type CopilotProviderExecution,
type ProviderDriverSpec,
} from './provider-runtime-contract';
import { CopilotProviderType, ModelOutputType } from './types';
export const DEFAULT_DIMENSIONS = 256;
export type MorphConfig = {
apiKey?: string;
};
export class MorphProvider extends CopilotProvider<MorphConfig> {
readonly type = CopilotProviderType.Morph;
protected resolveModelBackendKind() {
return 'morph' as const;
}
override configured(execution?: CopilotProviderExecution): boolean {
return !!this.getConfig(execution).apiKey;
}
private handleError(e: any) {
if (e instanceof UserFriendlyError) {
return e;
}
return new CopilotProviderSideError({
provider: this.type,
kind: 'unexpected_response',
message: e?.message || 'Unexpected morph response',
});
}
private createNativeConfig(
execution?: CopilotProviderExecution
): LlmBackendConfig {
return {
base_url: 'https://api.morphllm.com',
auth_token: this.getConfig(execution).apiKey ?? '',
};
}
override getDriverSpec(): ProviderDriverSpec {
return {
createBackendConfig: execution => this.createNativeConfig(execution),
mapError: error => this.handleError(error),
chat: {
resolveOutputType: kind =>
kind === 'streamObject' ? null : ModelOutputType.Text,
},
structured: false,
embedding: false,
rerank: false,
};
}
}

View File

@@ -14,7 +14,6 @@ import {
AttachmentAdmissionHost, AttachmentAdmissionHost,
} from '../runtime/hosts/attachment-admission'; } from '../runtime/hosts/attachment-admission';
import { AttachmentMaterializer } from '../runtime/hosts/attachment-materializer'; import { AttachmentMaterializer } from '../runtime/hosts/attachment-materializer';
import type { CopilotTool } from '../tools';
import { CopilotProvider } from './provider'; import { CopilotProvider } from './provider';
import { hasProviderModelBehaviorFlag } from './provider-model-runtime'; import { hasProviderModelBehaviorFlag } from './provider-model-runtime';
import type { import type {
@@ -22,7 +21,6 @@ import type {
ProviderDriverSpec, ProviderDriverSpec,
} from './provider-runtime-contract'; } from './provider-runtime-contract';
import { import {
CopilotChatTools,
CopilotProviderType, CopilotProviderType,
type PromptAttachment, type PromptAttachment,
type PromptMessage, type PromptMessage,
@@ -64,16 +62,6 @@ export class OpenAIProvider extends CopilotProvider<OpenAIConfig> {
}); });
} }
override getProviderSpecificTools(
toolName: CopilotChatTools,
_model: string
): [string, CopilotTool?] | undefined {
if (toolName === 'docEdit') {
return ['doc_edit', undefined];
}
return;
}
protected createNativeConfig( protected createNativeConfig(
execution?: CopilotProviderExecution execution?: CopilotProviderExecution
): LlmBackendConfig { ): LlmBackendConfig {

View File

@@ -1,75 +0,0 @@
import { CopilotProviderSideError } from '../../../base';
import { type LlmBackendConfig } from '../../../native';
import { CopilotProvider } from './provider';
import { hasProviderModelBehaviorFlag } from './provider-model-runtime';
import {
type CopilotProviderExecution,
type ProviderDriverSpec,
} from './provider-runtime-contract';
import { CopilotProviderType, ModelOutputType } from './types';
export type PerplexityConfig = {
apiKey: string;
endpoint?: string;
};
export class PerplexityProvider extends CopilotProvider<PerplexityConfig> {
readonly type = CopilotProviderType.Perplexity;
protected resolveModelBackendKind() {
return 'perplexity' as const;
}
override configured(execution?: CopilotProviderExecution): boolean {
return !!this.getConfig(execution).apiKey;
}
override getDriverSpec(): ProviderDriverSpec {
return {
createBackendConfig: execution => this.createNativeConfig(execution),
mapError: error => this.handleError(error),
chat: {
resolveOutputType: kind =>
kind === 'streamObject' ? null : ModelOutputType.Text,
withAttachment: false,
resolveRequestOptions: async context => ({
withAttachment: !hasProviderModelBehaviorFlag(
context.model,
'no_attachments'
),
include: hasProviderModelBehaviorFlag(
context.model,
'citations_include'
)
? ['citations']
: undefined,
}),
},
structured: false,
embedding: false,
rerank: false,
};
}
private createNativeConfig(
execution?: CopilotProviderExecution
): LlmBackendConfig {
const config = this.getConfig(execution);
const baseUrl = config.endpoint || 'https://api.perplexity.ai';
return {
base_url: baseUrl.replace(/\/v1\/?$/, ''),
auth_token: config.apiKey,
};
}
private handleError(e: any) {
if (e instanceof CopilotProviderSideError) {
return e;
}
return new CopilotProviderSideError({
provider: this.type,
kind: 'unexpected_response',
message: e?.message || 'Unexpected perplexity response',
});
}
}

View File

@@ -21,18 +21,6 @@ const DEFAULT_MIDDLEWARE_BY_TYPE: Record<
[CopilotProviderType.AnthropicVertex]: { [CopilotProviderType.AnthropicVertex]: {
node: { text: DEFAULT_NODE_TEXT_MIDDLEWARE }, node: { text: DEFAULT_NODE_TEXT_MIDDLEWARE },
}, },
[CopilotProviderType.Morph]: {
rust: {
request: ['clamp_max_tokens'],
},
node: { text: DEFAULT_NODE_TEXT_MIDDLEWARE },
},
[CopilotProviderType.Perplexity]: {
rust: {
request: ['clamp_max_tokens'],
},
node: { text: DEFAULT_NODE_TEXT_MIDDLEWARE },
},
[CopilotProviderType.Gemini]: { [CopilotProviderType.Gemini]: {
node: { text: DEFAULT_NODE_TEXT_MIDDLEWARE }, node: { text: DEFAULT_NODE_TEXT_MIDDLEWARE },
}, },

View File

@@ -15,10 +15,8 @@ const LEGACY_PROVIDER_ORDER: CopilotProviderType[] = [
CopilotProviderType.FAL, CopilotProviderType.FAL,
CopilotProviderType.Gemini, CopilotProviderType.Gemini,
CopilotProviderType.GeminiVertex, CopilotProviderType.GeminiVertex,
CopilotProviderType.Perplexity,
CopilotProviderType.Anthropic, CopilotProviderType.Anthropic,
CopilotProviderType.AnthropicVertex, CopilotProviderType.AnthropicVertex,
CopilotProviderType.Morph,
]; ];
const LEGACY_PROVIDER_PRIORITY = LEGACY_PROVIDER_ORDER.reduce( const LEGACY_PROVIDER_PRIORITY = LEGACY_PROVIDER_ORDER.reduce(

View File

@@ -5,9 +5,7 @@ import {
import { CloudflareWorkersAIProvider } from './cloudflare'; import { CloudflareWorkersAIProvider } from './cloudflare';
import { FalProvider } from './fal'; import { FalProvider } from './fal';
import { GeminiGenerativeProvider, GeminiVertexProvider } from './gemini'; import { GeminiGenerativeProvider, GeminiVertexProvider } from './gemini';
import { MorphProvider } from './morph';
import { OpenAIProvider } from './openai'; import { OpenAIProvider } from './openai';
import { PerplexityProvider } from './perplexity';
export const CopilotProviders = [ export const CopilotProviders = [
OpenAIProvider, OpenAIProvider,
@@ -15,8 +13,6 @@ export const CopilotProviders = [
FalProvider, FalProvider,
GeminiGenerativeProvider, GeminiGenerativeProvider,
GeminiVertexProvider, GeminiVertexProvider,
PerplexityProvider,
AnthropicOfficialProvider, AnthropicOfficialProvider,
AnthropicVertexProvider, AnthropicVertexProvider,
MorphProvider,
]; ];

View File

@@ -30,8 +30,6 @@ export enum CopilotProviderType {
Gemini = 'gemini', Gemini = 'gemini',
GeminiVertex = 'geminiVertex', GeminiVertex = 'geminiVertex',
OpenAI = 'openai', OpenAI = 'openai',
Perplexity = 'perplexity',
Morph = 'morph',
} }
export const CopilotProviderSchema = z.object({ export const CopilotProviderSchema = z.object({
@@ -80,8 +78,6 @@ export const PromptToolsSchema = z
'blobRead', 'blobRead',
'codeArtifact', 'codeArtifact',
'conversationSummary', 'conversationSummary',
// work with morph
'docEdit',
// work with indexer // work with indexer
'docRead', 'docRead',
'docCreate', 'docCreate',
@@ -268,6 +264,22 @@ const CopilotProviderOptionsSchema = z.object({
user: z.string().optional(), user: z.string().optional(),
session: z.string().optional(), session: z.string().optional(),
workspace: z.string().optional(), workspace: z.string().optional(),
byokLeaseId: z.string().optional(),
billingUnitId: z.string().optional(),
taskId: z.string().optional(),
actionId: z.string().optional(),
quotaBackedRoutesAllowed: z.boolean().optional(),
featureKind: z
.enum([
'chat',
'action',
'image',
'embedding',
'workspace_indexing',
'rerank',
'transcript',
])
.optional(),
}); });
export const CopilotChatOptionsSchema = CopilotProviderOptionsSchema.merge( export const CopilotChatOptionsSchema = CopilotProviderOptionsSchema.merge(

View File

@@ -164,11 +164,6 @@ export function toError(error: unknown): Error {
} }
} }
type DocEditFootnote = {
intent: string;
result: string;
};
function asRecord(value: unknown): Record<string, unknown> | null { function asRecord(value: unknown): Record<string, unknown> | null {
if (value && typeof value === 'object' && !Array.isArray(value)) { if (value && typeof value === 'object' && !Array.isArray(value)) {
return value as Record<string, unknown>; return value as Record<string, unknown>;
@@ -184,8 +179,6 @@ export class TextStreamParser {
private prefix: string | null = this.CALLOUT_PREFIX; private prefix: string | null = this.CALLOUT_PREFIX;
private readonly docEditFootnotes: DocEditFootnote[] = [];
public parse(chunk: CopilotTextStreamPart) { public parse(chunk: CopilotTextStreamPart) {
let result = ''; let result = '';
switch (chunk.type) { switch (chunk.type) {
@@ -233,13 +226,6 @@ export class TextStreamParser {
result += `\nWriting document "${chunk.input.title}"\n`; result += `\nWriting document "${chunk.input.title}"\n`;
break; break;
} }
case 'doc_edit': {
this.docEditFootnotes.push({
intent: String(chunk.input.instructions ?? ''),
result: '',
});
break;
}
} }
result = this.markAsCallout(result); result = this.markAsCallout(result);
break; break;
@@ -250,22 +236,6 @@ export class TextStreamParser {
); );
result = this.addPrefix(result); result = this.addPrefix(result);
switch (chunk.toolName) { switch (chunk.toolName) {
case 'doc_edit': {
const output = asRecord(chunk.output);
const array = output?.result;
if (Array.isArray(array)) {
result += array
.map(item => {
return `\n${String(asRecord(item)?.changedContent ?? '')}\n`;
})
.join('');
this.docEditFootnotes[this.docEditFootnotes.length - 1].result =
result;
} else {
this.docEditFootnotes.pop();
}
break;
}
case 'doc_semantic_search': { case 'doc_semantic_search': {
const output = chunk.output; const output = chunk.output;
if (Array.isArray(output)) { if (Array.isArray(output)) {
@@ -319,10 +289,7 @@ export class TextStreamParser {
} }
public end() { public end() {
const footnotes = this.docEditFootnotes.map((footnote, index) => { return '';
return `[^edit${index + 1}]: ${JSON.stringify({ type: 'doc-edit', ...footnote })}`;
});
return footnotes.join('\n');
} }
private addPrefix(text: string) { private addPrefix(text: string) {

View File

@@ -1,4 +1,4 @@
import { BadRequestException, NotFoundException } from '@nestjs/common'; import { NotFoundException } from '@nestjs/common';
import { import {
Args, Args,
Field, Field,
@@ -7,7 +7,6 @@ import {
Mutation, Mutation,
ObjectType, ObjectType,
Parent, Parent,
Query,
registerEnumType, registerEnumType,
ResolveField, ResolveField,
Resolver, Resolver,
@@ -19,7 +18,6 @@ import {
CallMetric, CallMetric,
CopilotDocNotFound, CopilotDocNotFound,
CopilotFailedToCreateMessage, CopilotFailedToCreateMessage,
CopilotProviderSideError,
CopilotSessionNotFound, CopilotSessionNotFound,
type FileUpload, type FileUpload,
paginate, paginate,
@@ -28,10 +26,8 @@ import {
RequestMutex, RequestMutex,
Throttle, Throttle,
TooManyRequest, TooManyRequest,
UserFriendlyError,
} from '../../base'; } from '../../base';
import { CurrentUser } from '../../core/auth'; import { CurrentUser } from '../../core/auth';
import { DocReader } from '../../core/doc';
import { AccessController, DocAction } from '../../core/permission'; import { AccessController, DocAction } from '../../core/permission';
import { UserType } from '../../core/user'; import { UserType } from '../../core/user';
import type { ListSessionOptions, UpdateChatSession } from '../../models'; import type { ListSessionOptions, UpdateChatSession } from '../../models';
@@ -40,7 +36,6 @@ import { ConversationInboxService } from './conversation/inbox';
import { PromptService } from './prompt/service'; import { PromptService } from './prompt/service';
import { CopilotProviderFactory } from './providers/factory'; import { CopilotProviderFactory } from './providers/factory';
import { ModelOutputType, type StreamObject } from './providers/types'; import { ModelOutputType, type StreamObject } from './providers/types';
import { CapabilityRuntime } from './runtime/capability-runtime';
import { ChatSessionService } from './session'; import { ChatSessionService } from './session';
import { type ChatHistory, type ChatMessage, SubmittedMessage } from './types'; import { type ChatHistory, type ChatMessage, SubmittedMessage } from './types';
@@ -376,9 +371,7 @@ export class CopilotResolver {
private readonly chatSession: ChatSessionService, private readonly chatSession: ChatSessionService,
private readonly historyProjector: CompatHistoryProjector, private readonly historyProjector: CompatHistoryProjector,
private readonly inbox: ConversationInboxService, private readonly inbox: ConversationInboxService,
private readonly docReader: DocReader, private readonly providerFactory: CopilotProviderFactory
private readonly providerFactory: CopilotProviderFactory,
private readonly runtime: CapabilityRuntime
) {} ) {}
@ResolveField(() => CopilotQuotaType, { @ResolveField(() => CopilotQuotaType, {
@@ -641,8 +634,6 @@ export class CopilotResolver {
throw new TooManyRequest('Server is busy'); throw new TooManyRequest('Server is busy');
} }
await this.chatSession.checkQuota(user.id);
return await this.chatSession.create({ return await this.chatSession.create({
...options, ...options,
pinned: options.pinned ?? false, pinned: options.pinned ?? false,
@@ -724,7 +715,6 @@ export class CopilotResolver {
throw new TooManyRequest('Server is busy'); throw new TooManyRequest('Server is busy');
} }
await this.chatSession.checkQuota(user.id);
return await this.chatSession.update({ return await this.chatSession.update({
...options, ...options,
userId: user.id, userId: user.id,
@@ -752,8 +742,6 @@ export class CopilotResolver {
throw new CopilotDocNotFound({ docId: options.docId }); throw new CopilotDocNotFound({ docId: options.docId });
} }
await this.chatSession.checkQuota(user.id);
return await this.chatSession.fork({ return await this.chatSession.fork({
...options, ...options,
userId: user.id, userId: user.id,
@@ -819,96 +807,6 @@ export class CopilotResolver {
} }
} }
@Query(() => String, {
description:
'Apply updates to a doc using LLM and return the merged markdown.',
deprecationReason: 'use Mutation.applyDocUpdates',
})
async applyDocUpdates(
@CurrentUser() user: CurrentUser,
@Args({ name: 'workspaceId', type: () => String })
workspaceId: string,
@Args({ name: 'docId', type: () => String })
docId: string,
@Args({ name: 'op', type: () => String })
op: string,
@Args({ name: 'updates', type: () => String })
updates: string
): Promise<string> {
return this.applyDocUpdatesInternal(user, workspaceId, docId, op, updates);
}
@Mutation(() => String, {
description:
'Apply updates to a doc using LLM and return the merged markdown.',
name: 'applyDocUpdates',
})
async applyDocUpdatesMutation(
@CurrentUser() user: CurrentUser,
@Args({ name: 'workspaceId', type: () => String })
workspaceId: string,
@Args({ name: 'docId', type: () => String })
docId: string,
@Args({ name: 'op', type: () => String })
op: string,
@Args({ name: 'updates', type: () => String })
updates: string
): Promise<string> {
return this.applyDocUpdatesInternal(user, workspaceId, docId, op, updates);
}
private async applyDocUpdatesInternal(
user: CurrentUser,
workspaceId: string,
docId: string,
op: string,
updates: string
): Promise<string> {
await this.assertPermission(user, { workspaceId, docId });
const docContent = await this.docReader.getDocMarkdown(
workspaceId,
docId,
true
);
if (!docContent || !docContent.markdown) {
throw new NotFoundException('Doc not found or empty');
}
const markdown = docContent.markdown.trim();
const resolved = await this.providerFactory.resolveProvider({
modelId: 'morph-v3-large',
outputType: ModelOutputType.Text,
});
if (!resolved) {
throw new BadRequestException('No LLM provider available');
}
try {
return await this.runtime.text(
{ modelId: 'morph-v3-large' },
[
{
role: 'user',
content: `<instruction>${op}</instruction>\n<code>${markdown}</code>\n<update>${updates}</update>`,
},
],
{ reasoning: false }
);
} catch (e: any) {
if (e instanceof UserFriendlyError) {
throw e;
} else {
throw new CopilotProviderSideError({
provider: resolved.provider.type,
kind: 'unexpected_response',
message: e?.message || 'Unexpected apply response',
});
}
}
}
private transformToSessionType(session: Omit<ChatHistory, 'messages'>) { private transformToSessionType(session: Omit<ChatHistory, 'messages'>) {
return { id: session.sessionId, ...session }; return { id: session.sessionId, ...session };
} }

View File

@@ -269,19 +269,20 @@ export class ActionRuntimeBridge {
attempt, attempt,
}); });
const inputWithBillingUnit = this.withBillingUnit(input, run.id);
let finalEvent: NativeActionEvent | undefined; let finalEvent: NativeActionEvent | undefined;
const attachments: unknown[] = []; const attachments: unknown[] = [];
try { try {
const nativeInput = await this.prepareNativeInput({ const nativeInput = await this.prepareNativeInput({
...input, ...inputWithBillingUnit,
}); });
for await (const event of this.runNativeStream( for await (const event of this.runNativeStream(
{ {
...nativeInput, ...nativeInput,
recipeId: input.actionId, recipeId: inputWithBillingUnit.actionId,
recipeVersion: input.actionVersion, recipeVersion: inputWithBillingUnit.actionVersion,
}, },
input.signal inputWithBillingUnit.signal
)) { )) {
finalEvent = event; finalEvent = event;
let projectedEvent = event; let projectedEvent = event;
@@ -343,4 +344,40 @@ export class ActionRuntimeBridge {
}); });
} }
} }
private withBillingUnit(
input: ActionRuntimeBridgeInput,
billingUnitId: string
): ActionRuntimeBridgeInput {
return {
...input,
prepareStructuredRoutes: input.prepareStructuredRoutes
? {
...input.prepareStructuredRoutes,
options: {
...input.prepareStructuredRoutes.options,
actionId:
input.prepareStructuredRoutes.options?.actionId ??
input.actionId,
billingUnitId:
input.prepareStructuredRoutes.options?.billingUnitId ??
billingUnitId,
},
}
: undefined,
prepareImageRoutes: input.prepareImageRoutes
? {
...input.prepareImageRoutes,
options: {
...input.prepareImageRoutes.options,
actionId:
input.prepareImageRoutes.options?.actionId ?? input.actionId,
billingUnitId:
input.prepareImageRoutes.options?.billingUnitId ??
billingUnitId,
},
}
: undefined,
};
}
} }

View File

@@ -411,6 +411,7 @@ function stripHostOnlyOptions<TOptions extends object | undefined>(
user: _user, user: _user,
session: _session, session: _session,
workspace: _workspace, workspace: _workspace,
quotaBackedRoutesAllowed: _quotaBackedRoutesAllowed,
...serializable ...serializable
} = options as Record<string, unknown>; } = options as Record<string, unknown>;

View File

@@ -90,6 +90,8 @@ export class ActionStreamHost {
prepared.session, prepared.session,
params, params,
userId, userId,
parsedQuery.byokLeaseId,
prepared.quotaBackedRoutesAllowed,
signal signal
); );
const runStream = this.bridge.runStream({ const runStream = this.bridge.runStream({
@@ -130,6 +132,9 @@ export class ActionStreamHost {
user: userId, user: userId,
workspace: prepared.session.config.workspaceId, workspace: prepared.session.config.workspaceId,
session: sessionId, session: sessionId,
byokLeaseId: parsedQuery.byokLeaseId,
quotaBackedRoutesAllowed: prepared.quotaBackedRoutesAllowed,
featureKind: 'action',
}, },
}, },
prepareImageRoutes: imageRoutes prepareImageRoutes: imageRoutes
@@ -177,6 +182,8 @@ export class ActionStreamHost {
session: ChatSession, session: ChatSession,
params: Record<string, unknown>, params: Record<string, unknown>,
userId: string, userId: string,
byokLeaseId?: string,
quotaBackedRoutesAllowed?: boolean,
signal?: AbortSignal signal?: AbortSignal
): Promise<ImageActionRoutePreparation | undefined> { ): Promise<ImageActionRoutePreparation | undefined> {
if (!isImageAction(actionId)) { if (!isImageAction(actionId)) {
@@ -201,6 +208,9 @@ export class ActionStreamHost {
user: userId, user: userId,
workspace: session.config.workspaceId, workspace: session.config.workspaceId,
session: session.config.sessionId, session: session.config.sessionId,
byokLeaseId,
quotaBackedRoutesAllowed,
featureKind: 'image',
}, },
}; };
} }

View File

@@ -18,6 +18,10 @@ export type ChatSelectionOptions = {
reasoning?: boolean; reasoning?: boolean;
webSearch?: boolean; webSearch?: boolean;
toolsConfig?: ToolsConfig; toolsConfig?: ToolsConfig;
byokLeaseId?: string;
billingUnitId?: string;
featureKind?: 'chat' | 'action' | 'image';
quotaBackedRoutesAllowed?: boolean;
}; };
type ResolvePolicyModelInput = ResolveModelInput & { type ResolvePolicyModelInput = ResolveModelInput & {
@@ -97,6 +101,10 @@ export class CapabilityPolicyHost {
user: session.config.userId, user: session.config.userId,
session: session.config.sessionId, session: session.config.sessionId,
workspace: session.config.workspaceId, workspace: session.config.workspaceId,
byokLeaseId: options.byokLeaseId,
billingUnitId: options.billingUnitId,
featureKind: options.featureKind ?? 'chat',
quotaBackedRoutesAllowed: options.quotaBackedRoutesAllowed,
reasoning: options.reasoning, reasoning: options.reasoning,
webSearch: options.webSearch, webSearch: options.webSearch,
tools, tools,

View File

@@ -5,6 +5,7 @@ import {
CopilotSessionNotFound, CopilotSessionNotFound,
Mutex, Mutex,
} from '../../../../base'; } from '../../../../base';
import { CopilotAccessPolicy } from '../../access';
import { CompatSubmissionStore } from '../../compat/submission-store'; import { CompatSubmissionStore } from '../../compat/submission-store';
import { import {
canonicalizeTurnTrace, canonicalizeTurnTrace,
@@ -20,6 +21,12 @@ export type PreparedConversationTurn = {
params: Record<string, string>; params: Record<string, string>;
session: ChatSession; session: ChatSession;
latestTurn?: Turn; latestTurn?: Turn;
quotaBackedRoutesAllowed?: boolean;
};
type AppendedSessionMessage = {
turn?: Turn;
quotaBackedRoutesAllowed?: boolean;
}; };
@Injectable() @Injectable()
@@ -27,7 +34,8 @@ export class ConversationHost {
constructor( constructor(
private readonly sessions: ChatSessionService, private readonly sessions: ChatSessionService,
private readonly submissions: CompatSubmissionStore, private readonly submissions: CompatSubmissionStore,
private readonly mutex: Mutex private readonly mutex: Mutex,
private readonly access: CopilotAccessPolicy
) {} ) {}
private async loadAcceptedTurn( private async loadAcceptedTurn(
@@ -101,12 +109,32 @@ export class ConversationHost {
session: ChatSession, session: ChatSession,
sessionId: string, sessionId: string,
messageId?: string, messageId?: string,
retry = false retry = false,
): Promise<Turn | undefined> { byokLeaseId?: string
): Promise<AppendedSessionMessage> {
const resolveChatRouteAccess = () =>
this.access.resolveTurnRouteAccess({
userId,
workspaceId: session.config.workspaceId,
byokLeaseId,
featureKind: 'chat',
});
if (!messageId) { if (!messageId) {
await this.sessions.revertLatestMessage(sessionId, false); await this.sessions.revertLatestMessage(sessionId, false);
session.revertLatestMessage(false); session.revertLatestMessage(false);
return session.latestUserTurn; if (!session.latestUserTurn) {
const routeAccess = await resolveChatRouteAccess();
return {
turn: session.latestUserTurn,
quotaBackedRoutesAllowed: routeAccess.quotaBackedRoutesAllowed,
};
}
const routeAccess = await resolveChatRouteAccess();
return {
turn: session.latestUserTurn,
quotaBackedRoutesAllowed: routeAccess.quotaBackedRoutesAllowed,
};
} }
const acceptedTurn = await this.loadAcceptedTurn( const acceptedTurn = await this.loadAcceptedTurn(
@@ -116,7 +144,7 @@ export class ConversationHost {
retry retry
); );
if (acceptedTurn) { if (acceptedTurn) {
return acceptedTurn; return { turn: acceptedTurn, quotaBackedRoutesAllowed: true };
} }
await using lock = await this.mutex.acquire( await using lock = await this.mutex.acquire(
@@ -132,7 +160,9 @@ export class ConversationHost {
messageId, messageId,
retry retry
); );
if (acceptedAfterLock) return acceptedAfterLock; if (acceptedAfterLock) {
return { turn: acceptedAfterLock, quotaBackedRoutesAllowed: true };
}
const durableTurn = await this.loadDurableTurn( const durableTurn = await this.loadDurableTurn(
session, session,
@@ -140,9 +170,14 @@ export class ConversationHost {
messageId, messageId,
retry retry
); );
if (durableTurn) return durableTurn; if (durableTurn) {
return {
turn: durableTurn,
quotaBackedRoutesAllowed: true,
};
}
await this.sessions.checkQuota(userId); const routeAccess = await resolveChatRouteAccess();
const submission = await this.submissions.get(messageId); const submission = await this.submissions.get(messageId);
if (!submission || submission.sessionId !== sessionId) { if (!submission || submission.sessionId !== sessionId) {
@@ -176,7 +211,10 @@ export class ConversationHost {
turnId: turn.id ?? '', turnId: turn.id ?? '',
}); });
session.pushPersistedTurn(turn); session.pushPersistedTurn(turn);
return turn; return {
turn,
quotaBackedRoutesAllowed: routeAccess.quotaBackedRoutesAllowed,
};
} }
async prepareTurn( async prepareTurn(
@@ -184,27 +222,30 @@ export class ConversationHost {
sessionId: string, sessionId: string,
query: Record<string, string | string[]> query: Record<string, string | string[]>
): Promise<PreparedConversationTurn> { ): Promise<PreparedConversationTurn> {
const { messageId, retry, params } = ChatQuerySchema.parse(query); const { messageId, retry, params, byokLeaseId } =
ChatQuerySchema.parse(query);
const session = await this.sessions.get(sessionId); const session = await this.sessions.get(sessionId);
if (!session || session.config.userId !== userId) { if (!session || session.config.userId !== userId) {
throw new CopilotSessionNotFound(); throw new CopilotSessionNotFound();
} }
const latestMessage = await this.appendSessionMessage( const appended = await this.appendSessionMessage(
userId, userId,
session, session,
sessionId, sessionId,
messageId, messageId,
retry retry,
byokLeaseId
); );
const currentUserMessage = const currentUserMessage =
session.stashTurns.findLast(turn => turn.role === 'user') ?? session.stashTurns.findLast(turn => turn.role === 'user') ??
latestMessage; appended.turn;
return { return {
messageId, messageId,
params, params,
session, session,
latestTurn: currentUserMessage, latestTurn: currentUserMessage,
quotaBackedRoutesAllowed: appended.quotaBackedRoutesAllowed,
}; };
} }

View File

@@ -1,4 +1,4 @@
import { Injectable } from '@nestjs/common'; import { Injectable, Logger } from '@nestjs/common';
import { NoCopilotProviderAvailable } from '../../../base'; import { NoCopilotProviderAvailable } from '../../../base';
import { import {
@@ -13,6 +13,7 @@ import {
llmValidateJsonSchema, llmValidateJsonSchema,
parseNativeStructuredOutput, parseNativeStructuredOutput,
} from '../../../native'; } from '../../../native';
import { type ByokFeatureKind, ByokService } from '../byok';
import { type StreamObject } from '../providers/types'; import { type StreamObject } from '../providers/types';
import { CopilotExecutionMetrics } from './execution-metrics'; import { CopilotExecutionMetrics } from './execution-metrics';
import { import {
@@ -25,8 +26,11 @@ import { mapNativeSemanticError } from './native-errors';
import { import {
createNativeToolLoopAdapter, createNativeToolLoopAdapter,
NativeProviderAdapter, NativeProviderAdapter,
type NativeProviderAdapterOptions,
} from './tool/native-adapter'; } from './tool/native-adapter';
const logger = new Logger('NativeExecutionEngine');
function modelIdForError(modelId?: string) { function modelIdForError(modelId?: string) {
return modelId ?? 'auto'; return modelId ?? 'auto';
} }
@@ -60,6 +64,83 @@ function extractTextResponse(response: LlmDispatchResponse) {
.trim(); .trim();
} }
function getUsageContext(plan: ExecutionPlan) {
const options = 'options' in plan.request ? plan.request.options : undefined;
const requestFeatureKind =
plan.request.kind === 'text' ||
plan.request.kind === 'streamText' ||
plan.request.kind === 'streamObject'
? 'chat'
: plan.request.kind;
return {
workspaceId: options?.workspace,
userId: options?.user,
sessionId: options?.session,
taskId: options?.taskId,
actionId: options?.actionId,
billingUnitId: options?.billingUnitId,
featureKind: options?.featureKind ?? requestFeatureKind,
};
}
async function recordByokUsage(
byok: ByokService,
plan: ExecutionPlan,
input: {
providerId?: string;
model?: string | null;
usage?: LlmDispatchResponse['usage'];
}
) {
const context = getUsageContext(plan);
try {
await byok.recordUsage({
workspaceId: context.workspaceId,
userId: context.userId,
sessionId: context.sessionId,
taskId: context.taskId,
actionId: context.actionId,
billingUnitId: context.billingUnitId,
featureKind: context.featureKind as ByokFeatureKind,
providerId: input.providerId,
model: input.model,
usage: input.usage,
});
} catch (error) {
logger.warn(
`Failed to record BYOK usage: ${
error instanceof Error ? error.message : String(error)
}`
);
}
}
async function recordSingleByokRouteFailure(
byok: ByokService,
plan: ExecutionPlan,
error: unknown
) {
const [providerId] = plan.routePolicy.fallbackOrder;
if (plan.routePolicy.fallbackOrder.length !== 1 || !providerId) {
return;
}
const context = getUsageContext(plan);
try {
await byok.recordProviderFailure({
workspaceId: context.workspaceId,
providerId,
featureKind: context.featureKind as ByokFeatureKind,
error,
});
} catch (recordError) {
logger.warn(
`Failed to record BYOK provider failure: ${
recordError instanceof Error ? recordError.message : String(recordError)
}`
);
}
}
function recordPreparedDispatch( function recordPreparedDispatch(
executionMetrics: CopilotExecutionMetrics | undefined, executionMetrics: CopilotExecutionMetrics | undefined,
plan: ExecutionPlan, plan: ExecutionPlan,
@@ -72,7 +153,12 @@ function recordPreparedDispatch(
); );
} }
function createNativeChatAdapter(dispatch: NativeChatDispatchPlan) { function createNativeChatAdapter(
dispatch: NativeChatDispatchPlan,
options?: {
onUsage?: NativeProviderAdapterOptions['onUsage'];
}
) {
if (dispatch.hasTools) { if (dispatch.hasTools) {
return createNativeToolLoopAdapter( return createNativeToolLoopAdapter(
{ preparedRoutes: dispatch.routes }, { preparedRoutes: dispatch.routes },
@@ -80,6 +166,7 @@ function createNativeChatAdapter(dispatch: NativeChatDispatchPlan) {
{ {
maxSteps: dispatch.prepared.maxSteps, maxSteps: dispatch.prepared.maxSteps,
nodeTextMiddleware: dispatch.prepared.postprocess?.nodeTextMiddleware, nodeTextMiddleware: dispatch.prepared.postprocess?.nodeTextMiddleware,
onUsage: options?.onUsage,
} }
); );
} }
@@ -95,6 +182,7 @@ function createNativeChatAdapter(dispatch: NativeChatDispatchPlan) {
return new NativeProviderAdapter(nativeDispatch, { return new NativeProviderAdapter(nativeDispatch, {
nodeTextMiddleware: dispatch.prepared.postprocess?.nodeTextMiddleware, nodeTextMiddleware: dispatch.prepared.postprocess?.nodeTextMiddleware,
onUsage: options?.onUsage,
}); });
} }
@@ -102,30 +190,38 @@ async function runPreparedValuePlan<TResult>(
plan: ExecutionPlan, plan: ExecutionPlan,
routeCount: number, routeCount: number,
executionMetrics: CopilotExecutionMetrics | undefined, executionMetrics: CopilotExecutionMetrics | undefined,
run: () => Promise<TResult> run: () => Promise<TResult>,
byok: ByokService
) { ) {
recordPreparedDispatch(executionMetrics, plan, routeCount); recordPreparedDispatch(executionMetrics, plan, routeCount);
try { try {
return await run(); return await run();
} catch (error) { } catch (error) {
throw mapNativeSemanticError(error); const mapped = mapNativeSemanticError(error);
await recordSingleByokRouteFailure(byok, plan, mapped);
throw mapped;
} }
} }
async function* mapPreparedStreamErrors<T>( async function* mapPreparedStreamErrors<T>(
source: AsyncIterable<T> source: AsyncIterable<T>,
plan: ExecutionPlan,
byok: ByokService
): AsyncIterableIterator<T> { ): AsyncIterableIterator<T> {
try { try {
yield* source; yield* source;
} catch (error) { } catch (error) {
throw mapNativeSemanticError(error); const mapped = mapNativeSemanticError(error);
await recordSingleByokRouteFailure(byok, plan, mapped);
throw mapped;
} }
} }
async function runChatValuePlan( async function runChatValuePlan(
plan: ExecutionPlan, plan: ExecutionPlan,
dispatch: NativeChatDispatchPlan, dispatch: NativeChatDispatchPlan,
executionMetrics?: CopilotExecutionMetrics executionMetrics: CopilotExecutionMetrics | undefined,
byok: ByokService
) { ) {
const adapter = createNativeChatAdapter(dispatch); const adapter = createNativeChatAdapter(dispatch);
return await runPreparedValuePlan( return await runPreparedValuePlan(
@@ -140,6 +236,11 @@ async function runChatValuePlan(
const result = await llmDispatchPlan({ const result = await llmDispatchPlan({
preparedRoutes: dispatch.routes, preparedRoutes: dispatch.routes,
}); });
await recordByokUsage(byok, plan, {
providerId: result.provider_id,
model: result.response.model,
usage: result.response.usage,
});
return extractTextResponse(result.response); return extractTextResponse(result.response);
} }
@@ -152,16 +253,26 @@ async function runChatValuePlan(
plan.hostContext.signal, plan.hostContext.signal,
plan.request.messages plan.request.messages
); );
} },
byok
); );
} }
async function* runChatStreamPlan( async function* runChatStreamPlan(
plan: ExecutionPlan, plan: ExecutionPlan,
dispatch: NativeChatDispatchPlan, dispatch: NativeChatDispatchPlan,
executionMetrics?: CopilotExecutionMetrics executionMetrics: CopilotExecutionMetrics | undefined,
byok: ByokService
): AsyncIterableIterator<string | StreamObject> { ): AsyncIterableIterator<string | StreamObject> {
const adapter = createNativeChatAdapter(dispatch); const adapter = createNativeChatAdapter(dispatch, {
onUsage: async usage => {
await recordByokUsage(byok, plan, {
providerId: usage.providerId,
model: usage.model,
usage: usage.usage,
});
},
});
recordPreparedDispatch(executionMetrics, plan, dispatch.routes.length); recordPreparedDispatch(executionMetrics, plan, dispatch.routes.length);
if (plan.request.kind === 'streamText') { if (plan.request.kind === 'streamText') {
@@ -170,7 +281,9 @@ async function* runChatStreamPlan(
dispatch.prepared.request, dispatch.prepared.request,
plan.hostContext.signal, plan.hostContext.signal,
plan.request.messages plan.request.messages
) ),
plan,
byok
); );
return; return;
} }
@@ -181,7 +294,9 @@ async function* runChatStreamPlan(
dispatch.prepared.request, dispatch.prepared.request,
plan.hostContext.signal, plan.hostContext.signal,
plan.request.messages plan.request.messages
) ),
plan,
byok
); );
return; return;
} }
@@ -192,7 +307,8 @@ async function* runChatStreamPlan(
async function* runPreparedImageArtifactPlan( async function* runPreparedImageArtifactPlan(
dispatch: NativeImageDispatchPlan, dispatch: NativeImageDispatchPlan,
plan: ExecutionPlan, plan: ExecutionPlan,
executionMetrics?: CopilotExecutionMetrics executionMetrics: CopilotExecutionMetrics | undefined,
byok: ByokService
): AsyncIterableIterator<NativeImageArtifact> { ): AsyncIterableIterator<NativeImageArtifact> {
if (plan.request.kind !== 'image') { if (plan.request.kind !== 'image') {
throw new Error('image dispatch requires image plan'); throw new Error('image dispatch requires image plan');
@@ -204,8 +320,21 @@ async function* runPreparedImageArtifactPlan(
result = await llmImageDispatchPlan({ result = await llmImageDispatchPlan({
preparedRoutes: dispatch.routes, preparedRoutes: dispatch.routes,
}); });
await recordByokUsage(byok, plan, {
providerId: result.provider_id,
model: dispatch.prepared.route.model,
usage: result.response.usage
? {
prompt_tokens: result.response.usage.input_tokens ?? 0,
completion_tokens: result.response.usage.output_tokens ?? 0,
total_tokens: result.response.usage.total_tokens ?? 0,
}
: undefined,
});
} catch (error) { } catch (error) {
throw mapNativeSemanticError(error); const mapped = mapNativeSemanticError(error);
await recordSingleByokRouteFailure(byok, plan, mapped);
throw mapped;
} }
for (const artifact of result.response.images) { for (const artifact of result.response.images) {
yield artifact; yield artifact;
@@ -214,13 +343,14 @@ async function* runPreparedImageArtifactPlan(
async function executePreparedPlan( async function executePreparedPlan(
plan: ExecutionPlan, plan: ExecutionPlan,
executionMetrics?: CopilotExecutionMetrics executionMetrics: CopilotExecutionMetrics | undefined,
byok: ByokService
): Promise<string | number[][] | number[] | null> { ): Promise<string | number[][] | number[] | null> {
switch (plan.request.kind) { switch (plan.request.kind) {
case 'text': { case 'text': {
const dispatch = plan.nativeDispatch?.chat; const dispatch = plan.nativeDispatch?.chat;
return dispatch return dispatch
? await runChatValuePlan(plan, dispatch, executionMetrics) ? await runChatValuePlan(plan, dispatch, executionMetrics, byok)
: null; : null;
} }
case 'structured': { case 'structured': {
@@ -236,13 +366,19 @@ async function executePreparedPlan(
const result = await llmStructuredDispatchPlan({ const result = await llmStructuredDispatchPlan({
preparedRoutes: dispatch.routes, preparedRoutes: dispatch.routes,
}); });
await recordByokUsage(byok, plan, {
providerId: result.provider_id,
model: result.response.model,
usage: result.response.usage,
});
const parsed = parseNativeStructuredOutput(result.response); const parsed = parseNativeStructuredOutput(result.response);
const validated = llmValidateJsonSchema( const validated = llmValidateJsonSchema(
dispatch.prepared.request.schema, dispatch.prepared.request.schema,
parsed parsed
); );
return JSON.stringify(validated); return JSON.stringify(validated);
} },
byok
); );
} }
case 'embedding': { case 'embedding': {
@@ -258,8 +394,20 @@ async function executePreparedPlan(
const result = await llmEmbeddingDispatchPlan({ const result = await llmEmbeddingDispatchPlan({
preparedRoutes: dispatch.routes, preparedRoutes: dispatch.routes,
}); });
await recordByokUsage(byok, plan, {
providerId: result.provider_id,
model: result.response.model,
usage: result.response.usage
? {
prompt_tokens: result.response.usage.prompt_tokens,
completion_tokens: 0,
total_tokens: result.response.usage.total_tokens,
}
: undefined,
});
return result.response.embeddings; return result.response.embeddings;
} },
byok
); );
} }
case 'rerank': { case 'rerank': {
@@ -275,8 +423,13 @@ async function executePreparedPlan(
const result = await llmRerankDispatchPlan({ const result = await llmRerankDispatchPlan({
preparedRoutes: dispatch.routes, preparedRoutes: dispatch.routes,
}); });
await recordByokUsage(byok, plan, {
providerId: result.provider_id,
model: result.response.model,
});
return result.response.scores; return result.response.scores;
} },
byok
); );
} }
default: default:
@@ -286,14 +439,15 @@ async function executePreparedPlan(
function executePreparedStreamPlan( function executePreparedStreamPlan(
plan: ExecutionPlan, plan: ExecutionPlan,
executionMetrics?: CopilotExecutionMetrics executionMetrics: CopilotExecutionMetrics | undefined,
byok: ByokService
): AsyncIterableIterator<string | StreamObject> | null { ): AsyncIterableIterator<string | StreamObject> | null {
switch (plan.request.kind) { switch (plan.request.kind) {
case 'streamText': case 'streamText':
case 'streamObject': { case 'streamObject': {
const dispatch = plan.nativeDispatch?.chat; const dispatch = plan.nativeDispatch?.chat;
return dispatch return dispatch
? runChatStreamPlan(plan, dispatch, executionMetrics) ? runChatStreamPlan(plan, dispatch, executionMetrics, byok)
: null; : null;
} }
default: default:
@@ -312,7 +466,10 @@ function noRouteStream<T>(plan: ExecutionPlan) {
@Injectable() @Injectable()
export class NativeExecutionEngine { export class NativeExecutionEngine {
constructor(private readonly executionMetrics?: CopilotExecutionMetrics) {} constructor(
private readonly byok: ByokService,
private readonly executionMetrics?: CopilotExecutionMetrics
) {}
private noRoute(plan: ExecutionPlan): never { private noRoute(plan: ExecutionPlan): never {
throw new NoCopilotProviderAvailable({ throw new NoCopilotProviderAvailable({
@@ -328,7 +485,11 @@ export class NativeExecutionEngine {
async execute( async execute(
plan: ExecutionPlanForKind<ValueExecutionKind> plan: ExecutionPlanForKind<ValueExecutionKind>
): Promise<string | number[][] | number[]> { ): Promise<string | number[][] | number[]> {
const result = await executePreparedPlan(plan, this.executionMetrics); const result = await executePreparedPlan(
plan,
this.executionMetrics,
this.byok
);
if (result === null) { if (result === null) {
return this.noRoute(plan); return this.noRoute(plan);
} }
@@ -345,7 +506,11 @@ export class NativeExecutionEngine {
executeStream( executeStream(
plan: ExecutionPlanForKind<StreamExecutionKind> plan: ExecutionPlanForKind<StreamExecutionKind>
): AsyncIterableIterator<string | StreamObject> { ): AsyncIterableIterator<string | StreamObject> {
const result = executePreparedStreamPlan(plan, this.executionMetrics); const result = executePreparedStreamPlan(
plan,
this.executionMetrics,
this.byok
);
if (result) { if (result) {
return result; return result;
} }
@@ -361,7 +526,8 @@ export class NativeExecutionEngine {
return runPreparedImageArtifactPlan( return runPreparedImageArtifactPlan(
dispatch, dispatch,
plan, plan,
this.executionMetrics this.executionMetrics,
this.byok
); );
} }

View File

@@ -13,7 +13,6 @@ import {
} from '../providers/types'; } from '../providers/types';
import { import {
buildBlobContentGetter, buildBlobContentGetter,
buildContentGetter,
buildDocContentGetter, buildDocContentGetter,
buildDocCreateHandler, buildDocCreateHandler,
buildDocKeywordSearchGetter, buildDocKeywordSearchGetter,
@@ -27,7 +26,6 @@ import {
createConversationSummaryTool, createConversationSummaryTool,
createDocComposeTool, createDocComposeTool,
createDocCreateTool, createDocCreateTool,
createDocEditTool,
createDocKeywordSearchTool, createDocKeywordSearchTool,
createDocReadTool, createDocReadTool,
createDocSemanticSearchTool, createDocSemanticSearchTool,
@@ -68,6 +66,21 @@ export class ToolRuntime {
if (!options?.tools?.length) { if (!options?.tools?.length) {
return tools; return tools;
} }
const runPromptText = (
promptName: string,
params: Record<string, unknown>
) =>
this.promptRuntime.runText(promptName, params, {
providerOptions: {
user: options.user,
session: options.session,
workspace: options.workspace,
byokLeaseId: options.byokLeaseId,
billingUnitId: options.billingUnitId,
quotaBackedRoutesAllowed: options.quotaBackedRoutesAllowed,
featureKind: options.featureKind,
},
});
for (const tool of options.tools) { for (const tool of options.tools) {
const toolDef = resolveProviderSpecificTool?.(tool, model); const toolDef = resolveProviderSpecificTool?.(tool, model);
@@ -97,23 +110,13 @@ export class ToolRuntime {
break; break;
} }
case 'codeArtifact': { case 'codeArtifact': {
tools.code_artifact = createCodeArtifactTool( tools.code_artifact = createCodeArtifactTool(runPromptText);
this.promptRuntime.runText.bind(this.promptRuntime)
);
break; break;
} }
case 'conversationSummary': { case 'conversationSummary': {
tools.conversation_summary = createConversationSummaryTool( tools.conversation_summary = createConversationSummaryTool(
options.session, options.session,
this.promptRuntime.runText.bind(this.promptRuntime) runPromptText
);
break;
}
case 'docEdit': {
const getDocContent = buildContentGetter(this.ac, this.docReader);
tools.doc_edit = createDocEditTool(
this.promptRuntime.runText.bind(this.promptRuntime),
getDocContent.bind(null, options)
); );
break; break;
} }
@@ -177,15 +180,11 @@ export class ToolRuntime {
break; break;
} }
case 'docCompose': { case 'docCompose': {
tools.doc_compose = createDocComposeTool( tools.doc_compose = createDocComposeTool(runPromptText);
this.promptRuntime.runText.bind(this.promptRuntime)
);
break; break;
} }
case 'sectionEdit': { case 'sectionEdit': {
tools.section_edit = createSectionEditTool( tools.section_edit = createSectionEditTool(runPromptText);
this.promptRuntime.runText.bind(this.promptRuntime)
);
break; break;
} }
} }

View File

@@ -1,3 +1,5 @@
import { Logger } from '@nestjs/common';
import type { LlmRequest, LlmToolLoopStreamEvent } from '../../../../native'; import type { LlmRequest, LlmToolLoopStreamEvent } from '../../../../native';
import type { NodeTextMiddleware } from '../../config'; import type { NodeTextMiddleware } from '../../config';
import type { PromptMessage, StreamObject } from '../../providers/types'; import type { PromptMessage, StreamObject } from '../../providers/types';
@@ -20,9 +22,14 @@ type AttachmentFootnote = {
fileType: string; fileType: string;
}; };
type NativeProviderAdapterOptions = { export type NativeProviderAdapterOptions = {
maxSteps?: number; maxSteps?: number;
nodeTextMiddleware?: NodeTextMiddleware[]; nodeTextMiddleware?: NodeTextMiddleware[];
onUsage?: (input: {
providerId: string;
model?: string;
usage?: Extract<LlmToolLoopStreamEvent, { type: 'usage' }>['usage'];
}) => void | Promise<void>;
}; };
type NativeStreamDispatch = ConstructorParameters< type NativeStreamDispatch = ConstructorParameters<
@@ -103,9 +110,11 @@ function formatAttachmentFootnotes(
} }
export class NativeProviderAdapter { export class NativeProviderAdapter {
readonly logger = new Logger(NativeProviderAdapter.name);
readonly #runtime: NativeRuntimeAdapter; readonly #runtime: NativeRuntimeAdapter;
readonly #enableCallout: boolean; readonly #enableCallout: boolean;
readonly #enableCitationFootnote: boolean; readonly #enableCitationFootnote: boolean;
readonly #onUsage?: NativeProviderAdapterOptions['onUsage'];
constructor( constructor(
dispatchWithTools: NativeStreamDispatch, dispatchWithTools: NativeStreamDispatch,
@@ -120,6 +129,36 @@ export class NativeProviderAdapter {
enabledNodeTextMiddlewares.has('thinking_format'); enabledNodeTextMiddlewares.has('thinking_format');
this.#enableCitationFootnote = this.#enableCitationFootnote =
enabledNodeTextMiddlewares.has('citation_footnote'); enabledNodeTextMiddlewares.has('citation_footnote');
this.#onUsage = options.onUsage;
}
async #recordUsageOnProviderSelected(
event: { type: string; [key: string]: unknown },
state: {
model?: string;
usage?: Extract<LlmToolLoopStreamEvent, { type: 'usage' }>['usage'];
}
) {
if (
event.type !== 'provider_selected' ||
typeof event.provider_id !== 'string'
) {
return;
}
try {
await this.#onUsage?.({
providerId: event.provider_id,
model: state.model,
usage: state.usage,
});
} catch (error) {
this.logger.warn(
`Provider usage callback failed: ${
error instanceof Error ? error.message : String(error)
}`
);
}
state.usage = undefined;
} }
async text( async text(
@@ -144,6 +183,10 @@ export class NativeProviderAdapter {
? new CitationFootnoteFormatter() ? new CitationFootnoteFormatter()
: null; : null;
let streamPartId = 0; let streamPartId = 0;
const usageState: {
model?: string;
usage?: Extract<LlmToolLoopStreamEvent, { type: 'usage' }>['usage'];
} = {};
for await (const event of this.#runtime.streamEvents( for await (const event of this.#runtime.streamEvents(
request, request,
@@ -151,6 +194,22 @@ export class NativeProviderAdapter {
messages messages
)) { )) {
switch (event.type) { switch (event.type) {
case 'message_start': {
const startEvent = event as Extract<
LlmToolLoopStreamEvent,
{ type: 'message_start' }
>;
usageState.model = startEvent.model;
break;
}
case 'usage': {
const usageEvent = event as Extract<
LlmToolLoopStreamEvent,
{ type: 'usage' }
>;
usageState.usage = usageEvent.usage;
break;
}
case 'text_delta': { case 'text_delta': {
const textEvent = event as unknown as { text: string }; const textEvent = event as unknown as { text: string };
if (textParser) { if (textParser) {
@@ -216,6 +275,11 @@ export class NativeProviderAdapter {
break; break;
} }
case 'done': { case 'done': {
const doneEvent = event as Extract<
LlmToolLoopStreamEvent,
{ type: 'done' }
>;
usageState.usage = doneEvent.usage ?? usageState.usage;
const footnotes = textParser?.end() ?? ''; const footnotes = textParser?.end() ?? '';
const citations = citationFormatter?.end() ?? ''; const citations = citationFormatter?.end() ?? '';
const tails = [citations, footnotes].filter(Boolean).join('\n'); const tails = [citations, footnotes].filter(Boolean).join('\n');
@@ -224,6 +288,9 @@ export class NativeProviderAdapter {
} }
break; break;
} }
case 'provider_selected':
await this.#recordUsageOnProviderSelected(event, usageState);
break;
case 'error': case 'error':
throw new Error( throw new Error(
typeof event.message === 'string' typeof event.message === 'string'
@@ -246,6 +313,10 @@ export class NativeProviderAdapter {
: null; : null;
const fallbackAttachmentFootnotes = new Map<string, AttachmentFootnote>(); const fallbackAttachmentFootnotes = new Map<string, AttachmentFootnote>();
let hasFootnoteReference = false; let hasFootnoteReference = false;
const usageState: {
model?: string;
usage?: Extract<LlmToolLoopStreamEvent, { type: 'usage' }>['usage'];
} = {};
for await (const event of this.#runtime.streamEvents( for await (const event of this.#runtime.streamEvents(
request, request,
@@ -253,6 +324,22 @@ export class NativeProviderAdapter {
messages messages
)) { )) {
switch (event.type) { switch (event.type) {
case 'message_start': {
const startEvent = event as Extract<
LlmToolLoopStreamEvent,
{ type: 'message_start' }
>;
usageState.model = startEvent.model;
break;
}
case 'usage': {
const usageEvent = event as Extract<
LlmToolLoopStreamEvent,
{ type: 'usage' }
>;
usageState.usage = usageEvent.usage;
break;
}
case 'text_delta': { case 'text_delta': {
const textEvent = event as unknown as { text: string }; const textEvent = event as unknown as { text: string };
if (textEvent.text.includes('[^')) { if (textEvent.text.includes('[^')) {
@@ -302,6 +389,11 @@ export class NativeProviderAdapter {
break; break;
} }
case 'done': { case 'done': {
const doneEvent = event as Extract<
LlmToolLoopStreamEvent,
{ type: 'done' }
>;
usageState.usage = doneEvent.usage ?? usageState.usage;
const citations = citationFormatter?.end() ?? ''; const citations = citationFormatter?.end() ?? '';
if (citations) { if (citations) {
hasFootnoteReference = true; hasFootnoteReference = true;
@@ -318,6 +410,9 @@ export class NativeProviderAdapter {
} }
break; break;
} }
case 'provider_selected':
await this.#recordUsageOnProviderSelected(event, usageState);
break;
case 'error': case 'error':
throw new Error( throw new Error(
typeof event.message === 'string' typeof event.message === 'string'

View File

@@ -62,7 +62,7 @@ export class TurnOrchestrator {
sessionId, sessionId,
query query
); );
const { modelId, reasoning, webSearch, toolsConfig } = const { modelId, reasoning, webSearch, toolsConfig, byokLeaseId } =
ChatQuerySchema.parse(query); ChatQuerySchema.parse(query);
const promptParams = await this.buildPromptParams(sessionId, { const promptParams = await this.buildPromptParams(sessionId, {
latestTurn: prepared.latestTurn, latestTurn: prepared.latestTurn,
@@ -82,6 +82,15 @@ export class TurnOrchestrator {
reasoning, reasoning,
webSearch, webSearch,
toolsConfig, toolsConfig,
byokLeaseId,
billingUnitId: prepared.latestTurn?.id,
quotaBackedRoutesAllowed: prepared.quotaBackedRoutesAllowed,
featureKind:
selection.responseMode === 'image'
? 'image'
: selection.responseMode === 'object'
? 'action'
: 'chat',
}), }),
}; };
} }

View File

@@ -20,6 +20,7 @@ import {
type UpdateChatSession, type UpdateChatSession,
UpdateChatSessionOptions, UpdateChatSessionOptions,
} from '../../models'; } from '../../models';
import { CopilotAccessPolicy } from './access';
import { ConversationPolicy } from './conversation/policy'; import { ConversationPolicy } from './conversation/policy';
import { ConversationStore } from './conversation/store'; import { ConversationStore } from './conversation/store';
import { type Conversation, promptMessageFromTurn, type Turn } from './core'; import { type Conversation, promptMessageFromTurn, type Turn } from './core';
@@ -186,6 +187,7 @@ export class ChatSessionService {
private readonly models: Models, private readonly models: Models,
private readonly jobs: JobQueue, private readonly jobs: JobQueue,
private readonly store: ConversationStore, private readonly store: ConversationStore,
private readonly access: CopilotAccessPolicy,
private readonly conversationPolicy: ConversationPolicy, private readonly conversationPolicy: ConversationPolicy,
private readonly prompts: PromptService, private readonly prompts: PromptService,
private readonly promptRuntime: PromptRuntime private readonly promptRuntime: PromptRuntime
@@ -298,11 +300,11 @@ export class ChatSessionService {
} }
async getQuota(userId: string) { async getQuota(userId: string) {
return await this.conversationPolicy.getQuota(userId); return await this.access.getQuota(userId);
} }
async checkQuota(userId: string) { async checkQuota(userId: string) {
await this.conversationPolicy.checkQuota(userId); await this.access.checkQuota(userId);
} }
async create(options: ChatSessionOptions): Promise<string> { async create(options: ChatSessionOptions): Promise<string> {

View File

@@ -1,213 +0,0 @@
import { z } from 'zod';
import { DocReader } from '../../../core/doc';
import { AccessController } from '../../../core/permission';
import { defineTool } from './tool';
import type { CopilotChatOptions } from './types';
type RunPromptText = (
promptName: string,
params: Record<string, unknown>
) => Promise<string>;
const CodeEditSchema = z
.array(
z.object({
op: z
.string()
.describe(
'A short description of the change, such as "Bold intro name"'
),
updates: z
.string()
.describe(
'Markdown block fragments that represent the change, including the block_id and type'
),
})
)
.describe(
'An array of independent semantic changes to apply to the document.'
);
export const buildContentGetter = (ac: AccessController, doc: DocReader) => {
const getDocContent = async (options: CopilotChatOptions, docId?: string) => {
if (!options || !docId || !options.user || !options.workspace) {
return undefined;
}
const canAccess = await ac
.user(options.user)
.workspace(options.workspace)
.doc(docId)
.can('Doc.Read');
if (!canAccess) return undefined;
const content = await doc.getDocMarkdown(options.workspace, docId, true);
return content?.markdown.trim() || undefined;
};
return getDocContent;
};
export const createDocEditTool = (
prompt: RunPromptText,
getContent: (targetId?: string) => Promise<string | undefined>
) => {
return defineTool({
description: `
Use this tool to propose an edit to a structured Markdown document with identifiable blocks.
Each block begins with a comment like <!-- block_id=... -->, and represents a unit of editable content such as a heading, paragraph, list, or code snippet.
This will be read by a less intelligent model, which will quickly apply the edit. You should make it clear what the edit is, while also minimizing the unchanged code you write.
If you receive a markdown without block_id comments, you should call \`doc_read\` tool to get the content.
Your task is to return a list of block-level changes needed to fulfill the user's intent. **Each change in code_edit must be completely independent: each code_edit entry should only perform a single, isolated change, and must not include the effects of other changes. For example, the updates for a delete operation should only show the context related to the deletion, and must not include any content modified by other operations (such as bolding or insertion). This ensures that each change can be applied independently and in any order.**
Each change should correspond to a specific user instruction and be represented by one of the following operations:
replace: Replace the content of a block with updated Markdown.
delete: Remove a block entirely.
insert: Add a new block, and specify its block_id and content.
Important Instructions:
- Use the existing block structure as-is. Do not reformat or reorder blocks unless explicitly asked.
- When inserting, follow the same format as a replacement, but ensure the new block_id does not conflict with existing IDs.
- When replacing content, always keep the original block_id unchanged.
- When deleting content, only use the format <!-- delete block_id=xxx -->, and only for valid block_id present in the original <code> content.
- Each top-level list item should be a block. Like this:
\`\`\`markdown
<!-- block_id=001 flavour=affine:list -->
* Item 1
* SubItem 1
<!-- block_id=002 flavour=affine:list -->
1. Item 1
1. SubItem 1
\`\`\`
- Your task is to return a list of block-level changes needed to fulfill the user's intent.
- **Each change in code_edit must be completely independent: each code_edit entry should only perform a single, isolated change, and must not include the effects of other changes. For example, the updates for a delete operation should only show the context related to the deletion, and must not include any content modified by other operations (such as bolding or insertion). This ensures that each change can be applied independently and in any order.**
Original Content:
\`\`\`markdown
<!-- block_id=001 flavour=paragraph -->
# Andriy Shevchenko
<!-- block_id=002 flavour=paragraph -->
## Player Profile
<!-- block_id=003 flavour=paragraph -->
Andriy Shevchenko is a legendary Ukrainian striker, best known for his time at AC Milan and Dynamo Kyiv. He won the Ballon d'Or in 2004.
<!-- block_id=004 flavour=paragraph -->
## Career Overview
<!-- block_id=005 flavour=list -->
- Born in 1976 in Ukraine.
<!-- block_id=006 flavour=list -->
- Rose to fame at Dynamo Kyiv in the 1990s.
<!-- block_id=007 flavour=list -->
- Starred at AC Milan (19992006), scoring over 170 goals.
<!-- block_id=008 flavour=list -->
- Played for Chelsea (20062009) before returning to Kyiv.
<!-- block_id=009 flavour=list -->
- Coached Ukraine national team, reaching Euro 2020 quarter-finals.
\`\`\`
User Request
\`\`\`
Bold the players name in the intro, add a summary section at the end, and remove the career overview.
\`\`\`
Example response:
\`\`\`json
[
{
"op": "Bold the player's name in the introduction",
"updates": "
<!-- block_id=003 flavour=paragraph -->
**Andriy Shevchenko** is a legendary Ukrainian striker, best known for his time at AC Milan and Dynamo Kyiv. He won the Ballon d'Or in 2004.
"
},
{
"op": "Add a summary section at the end",
"updates": "
<!-- block_id=new-abc123 flavour=paragraph -->
## Summary
<!-- block_id=new-def456 flavour=paragraph -->
Shevchenko is celebrated as one of the greatest Ukrainian footballers of all time. Known for his composure, strength, and goal-scoring instinct, he left a lasting legacy both on and off the pitch.
"
},
{
"op": "Delete the career overview section",
"updates": "
<!-- delete block_id=004 -->
<!-- delete block_id=005 -->
<!-- delete block_id=006 -->
<!-- delete block_id=007 -->
<!-- delete block_id=008 -->
<!-- delete block_id=009 -->
"
}
]
\`\`\`
You should specify the following arguments before the others: [doc_id], [origin_content]
`,
inputSchema: z.object({
doc_id: z
.string()
.describe(
'The unique ID of the document being edited. Required when editing an existing document stored in the system. If you are editing ad-hoc Markdown content instead, leave this empty and use origin_content.'
)
.optional(),
origin_content: z
.string()
.describe(
'The full original Markdown content, including all block_id comments (e.g., <!-- block_id=block-001 type=paragraph -->). Required when doc_id is not provided. This content will be parsed into discrete editable blocks.'
)
.optional(),
instructions: z
.string()
.describe(
'A short, first-person description of the intended edit, clearly summarizing what I will change. For example: "I will translate the steps into English and delete the paragraph explaining the delay." This helps the downstream system understand the purpose of the changes.'
),
code_edit: z.preprocess(val => {
// BACKGROUND: LLM sometimes returns a JSON string instead of an array.
if (typeof val === 'string') {
return JSON.parse(val);
}
return val;
}, CodeEditSchema) as unknown as typeof CodeEditSchema,
}),
execute: async ({ doc_id, origin_content, code_edit }) => {
try {
const content = origin_content || (await getContent(doc_id));
if (!content) {
return 'Doc not found or doc is empty';
}
const changedContents = await Promise.all(
code_edit.map(async edit => {
return await prompt('Apply Updates', {
content,
op: edit.op,
updates: edit.updates,
});
})
);
return {
result: changedContents.map((changedContent, index) => ({
op: code_edit[index].op,
updates: code_edit[index].updates,
originalContent: content,
changedContent,
})),
};
} catch {
return 'Failed to apply edit to the doc';
}
},
});
};

View File

@@ -13,6 +13,11 @@ import { toolError } from './error';
import { defineTool } from './tool'; import { defineTool } from './tool';
import type { CopilotChatOptions } from './types'; import type { CopilotChatOptions } from './types';
const getEmbeddingRouteContext = (options: CopilotChatOptions) => ({
userId: options?.user,
byokLeaseId: options?.byokLeaseId,
});
export const buildDocSearchGetter = ( export const buildDocSearchGetter = (
ac: AccessController, ac: AccessController,
context: CopilotContextService, context: CopilotContextService,
@@ -43,12 +48,32 @@ export const buildDocSearchGetter = (
'Doc Semantic Search Failed', 'Doc Semantic Search Failed',
'You do not have permission to access this workspace.' 'You do not have permission to access this workspace.'
); );
const routeContext = getEmbeddingRouteContext(options);
const [chunks, contextChunks] = await Promise.all([ const [chunks, contextChunks] = await Promise.all([
context.matchWorkspaceAll(options.workspace, query, 10, signal), context.matchWorkspaceAll(
options.workspace,
query,
10,
signal,
0.8,
undefined,
0.85,
routeContext
),
sessionId sessionId
? context ? context
.getBySessionId(sessionId) .getBySessionId(sessionId)
.then(current => current?.matchFiles(query, 10, signal) ?? []) .then(
current =>
current?.matchFiles(
query,
10,
signal,
0.85,
0.5,
routeContext
) ?? []
)
: [], : [],
]); ]);

View File

@@ -2,7 +2,6 @@ export * from './blob-read';
export * from './code-artifact'; export * from './code-artifact';
export * from './conversation-summary'; export * from './conversation-summary';
export * from './doc-compose'; export * from './doc-compose';
export * from './doc-edit';
export * from './doc-keyword-search'; export * from './doc-keyword-search';
export * from './doc-read'; export * from './doc-read';
export * from './doc-semantic-search'; export * from './doc-semantic-search';

View File

@@ -10,7 +10,7 @@ import {
sniffMime, sniffMime,
} from '../../../base'; } from '../../../base';
import { Models } from '../../../models'; import { Models } from '../../../models';
import { ConversationPolicy } from '../conversation/policy'; import { CopilotAccessPolicy } from '../access';
import { PromptService } from '../prompt'; import { PromptService } from '../prompt';
import { CopilotProviderType } from '../providers/types'; import { CopilotProviderType } from '../providers/types';
import { ActionRuntimeBridge } from '../runtime/action-runtime-bridge'; import { ActionRuntimeBridge } from '../runtime/action-runtime-bridge';
@@ -62,7 +62,7 @@ export class CopilotTranscriptionService {
private readonly tasks: TaskPolicy, private readonly tasks: TaskPolicy,
private readonly prompts: PromptService, private readonly prompts: PromptService,
private readonly actionBridge: ActionRuntimeBridge, private readonly actionBridge: ActionRuntimeBridge,
@Optional() private readonly conversationPolicy?: ConversationPolicy @Optional() private readonly access?: CopilotAccessPolicy
) {} ) {}
private parseTaskPayload(payload: unknown): TranscriptionPayloadV2 { private parseTaskPayload(payload: unknown): TranscriptionPayloadV2 {
@@ -223,6 +223,12 @@ export class CopilotTranscriptionService {
throw new CopilotTranscriptionJobExists(); throw new CopilotTranscriptionJobExists();
} }
await this.access?.assertQuotaOrByok({
userId,
workspaceId,
featureKind: 'transcript',
});
const { model, strategy } = await this.resolveTranscriptStrategy( const { model, strategy } = await this.resolveTranscriptStrategy(
userId, userId,
input?.strategy ?? undefined input?.strategy ?? undefined
@@ -270,6 +276,12 @@ export class CopilotTranscriptionService {
); );
} }
await this.access?.assertQuotaOrByok({
userId,
workspaceId,
featureKind: 'transcript',
});
const payload = this.parseTaskPayload(task.protectedResult); const payload = this.parseTaskPayload(task.protectedResult);
const { model } = await this.resolveTranscriptStrategy( const { model } = await this.resolveTranscriptStrategy(
userId, userId,
@@ -307,13 +319,17 @@ export class CopilotTranscriptionService {
return null; return null;
} }
const settled = if (task.status === 'settled') {
task.status === 'settled' return this.taskToJob(task);
? task }
: await (async () => {
await this.conversationPolicy?.checkQuota(userId); await this.access?.assertQuotaOrByok({
return await this.models.copilotTranscriptTask.settle(task.id); userId,
})(); workspaceId,
featureKind: 'transcript',
});
const settled = await this.models.copilotTranscriptTask.settle(task.id);
return this.taskToJob(settled); return this.taskToJob(settled);
} }
@@ -378,6 +394,13 @@ export class CopilotTranscriptionService {
stepId: 'transcribe', stepId: 'transcribe',
modelId, modelId,
messages, messages,
options: {
user: task.userId,
workspace: task.workspaceId,
taskId,
billingUnitId: taskId,
featureKind: 'transcript',
},
prefer: CopilotProviderType.Gemini, prefer: CopilotProviderType.Gemini,
responseContract: TranscriptActionResultContract, responseContract: TranscriptActionResultContract,
}, },

View File

@@ -37,6 +37,7 @@ export const ChatQuerySchema = z
.object({ .object({
messageId: zMaybeString, messageId: zMaybeString,
modelId: zMaybeString, modelId: zMaybeString,
byokLeaseId: zMaybeString,
retry: zBool, retry: zBool,
reasoning: zBool, reasoning: zBool,
webSearch: zBool, webSearch: zBool,
@@ -47,6 +48,7 @@ export const ChatQuerySchema = z
({ ({
messageId, messageId,
modelId, modelId,
byokLeaseId,
retry, retry,
reasoning, reasoning,
webSearch, webSearch,
@@ -55,6 +57,7 @@ export const ChatQuerySchema = z
}) => ({ }) => ({
messageId, messageId,
modelId, modelId,
byokLeaseId,
retry, retry,
reasoning, reasoning,
webSearch, webSearch,

View File

@@ -288,6 +288,24 @@ type BlobUploadedPart {
partNumber: Int! partNumber: Int!
} }
enum ByokKeyStorage {
local
server
}
enum ByokKeyTestStatus {
failed
passed
untested
}
enum ByokProvider {
anthropic
fal
gemini
openai
}
type CalendarAccountObjectType { type CalendarAccountObjectType {
calendars: [CalendarSubscriptionObjectType!]! calendars: [CalendarSubscriptionObjectType!]!
calendarsCount: Int! calendarsCount: Int!
@@ -731,6 +749,26 @@ input CreateUserInput {
password: String password: String
} }
input CreateWorkspaceByokLocalLeaseInput {
providers: [CreateWorkspaceByokLocalLeaseProviderInput!]!
workspaceId: String!
}
input CreateWorkspaceByokLocalLeaseProviderInput {
apiKey: String!
description: String
enabled: Boolean
endpoint: String
name: String!
provider: ByokProvider!
sortOrder: SafeInt
}
type CreateWorkspaceByokLocalLeaseResultType {
expiresAt: DateTime!
leaseId: String!
}
type CredentialsRequirementType { type CredentialsRequirementType {
password: PasswordLimitsType! password: PasswordLimitsType!
} }
@@ -1514,9 +1552,6 @@ type Mutation {
"""Update workspace flags and features for admin""" """Update workspace flags and features for admin"""
adminUpdateWorkspace(input: AdminUpdateWorkspaceInput!): AdminWorkspace adminUpdateWorkspace(input: AdminUpdateWorkspaceInput!): AdminWorkspace
"""Apply updates to a doc using LLM and return the merged markdown."""
applyDocUpdates(docId: String!, op: String!, updates: String!, workspaceId: String!): String!
approveMember(userId: String!, workspaceId: String!): Boolean! approveMember(userId: String!, workspaceId: String!): Boolean!
"""Ban an user""" """Ban an user"""
@@ -1527,6 +1562,7 @@ type Mutation {
"""Cleanup sessions""" """Cleanup sessions"""
cleanupCopilotSession(options: DeleteSessionInput!): [String!]! cleanupCopilotSession(options: DeleteSessionInput!): [String!]!
clearWorkspaceByokConfigs(provider: ByokProvider, workspaceId: String!): Boolean!
completeBlobUpload(key: String!, parts: [BlobUploadPartInput!], uploadId: String, workspaceId: String!): String! completeBlobUpload(key: String!, parts: [BlobUploadPartInput!], uploadId: String, workspaceId: String!): String!
createBlobUpload(key: String!, mime: String!, size: Int!, workspaceId: String!): BlobUploadInit! createBlobUpload(key: String!, mime: String!, size: Int!, workspaceId: String!): BlobUploadInit!
@@ -1560,6 +1596,7 @@ type Mutation {
"""Create a new workspace""" """Create a new workspace"""
createWorkspace(init: Upload): WorkspaceType! createWorkspace(init: Upload): WorkspaceType!
createWorkspaceByokLocalLease(input: CreateWorkspaceByokLocalLeaseInput!): CreateWorkspaceByokLocalLeaseResultType!
deactivateLicense(workspaceId: String!): Boolean! deactivateLicense(workspaceId: String!): Boolean!
deleteAccount: DeleteAccount! deleteAccount: DeleteAccount!
deleteBlob(hash: String @deprecated(reason: "use parameter [key]"), key: String, permanently: Boolean! = false, workspaceId: String!): Boolean! deleteBlob(hash: String @deprecated(reason: "use parameter [key]"), key: String, permanently: Boolean! = false, workspaceId: String!): Boolean!
@@ -1573,6 +1610,7 @@ type Mutation {
"""Delete a user account""" """Delete a user account"""
deleteUser(id: String!): DeleteAccount! deleteUser(id: String!): DeleteAccount!
deleteWorkspace(id: String!): Boolean! deleteWorkspace(id: String!): Boolean!
deleteWorkspaceByokConfig(id: ID!, workspaceId: String!): Boolean!
"""Reenable an banned user""" """Reenable an banned user"""
enableUser(id: String!): UserType! enableUser(id: String!): UserType!
@@ -1628,6 +1666,7 @@ type Mutation {
"""Remove workspace embedding files""" """Remove workspace embedding files"""
removeWorkspaceEmbeddingFiles(fileId: String!, workspaceId: String!): Boolean! removeWorkspaceEmbeddingFiles(fileId: String!, workspaceId: String!): Boolean!
removeWorkspaceFeature(feature: FeatureType!, workspaceId: String!): Boolean! removeWorkspaceFeature(feature: FeatureType!, workspaceId: String!): Boolean!
reorderWorkspaceByokConfigs(input: ReorderWorkspaceByokConfigsInput!): [WorkspaceByokKeyConfigType!]!
"""Request to apply the subscription in advance""" """Request to apply the subscription in advance"""
requestApplySubscription(transactionId: String!): [SubscriptionType!]! requestApplySubscription(transactionId: String!): [SubscriptionType!]!
@@ -1650,6 +1689,7 @@ type Mutation {
setBlob(blob: Upload!, workspaceId: String!): String! setBlob(blob: Upload!, workspaceId: String!): String!
settleTranscriptTask(taskId: String!, workspaceId: String!): TranscriptionResultType settleTranscriptTask(taskId: String!, workspaceId: String!): TranscriptionResultType
submitTranscriptTask(blob: Upload, blobId: String!, blobs: [Upload!], input: SubmitAudioTranscriptionInput, workspaceId: String!): TranscriptionResultType submitTranscriptTask(blob: Upload, blobId: String!, blobs: [Upload!], input: SubmitAudioTranscriptionInput, workspaceId: String!): TranscriptionResultType
testWorkspaceByokConfig(input: TestWorkspaceByokConfigInput!): TestWorkspaceByokConfigResultType!
unlinkCalendarAccount(accountId: String!): Boolean! unlinkCalendarAccount(accountId: String!): Boolean!
"""update app configuration""" """update app configuration"""
@@ -1690,6 +1730,7 @@ type Mutation {
"""Upload a comment attachment and return the access url""" """Upload a comment attachment and return the access url"""
uploadCommentAttachment(attachment: Upload!, docId: String!, workspaceId: String!): String! uploadCommentAttachment(attachment: Upload!, docId: String!, workspaceId: String!): String!
upsertWorkspaceByokConfig(input: UpsertWorkspaceByokConfigInput!): WorkspaceByokKeyConfigType!
verifyEmail(token: String!): Boolean! verifyEmail(token: String!): Boolean!
} }
@@ -1907,9 +1948,6 @@ type Query {
"""get the whole app configuration""" """get the whole app configuration"""
appConfig: JSONObject! appConfig: JSONObject!
"""Apply updates to a doc using LLM and return the merged markdown."""
applyDocUpdates(docId: String!, op: String!, updates: String!, workspaceId: String!): String! @deprecated(reason: "use Mutation.applyDocUpdates")
"""Get current user""" """Get current user"""
currentUser: UserType currentUser: UserType
error(name: ErrorNames!): ErrorDataUnion! error(name: ErrorNames!): ErrorDataUnion!
@@ -2013,6 +2051,12 @@ input RemoveContextFileInput {
fileId: String! fileId: String!
} }
input ReorderWorkspaceByokConfigsInput {
ids: [ID!]!
storage: ByokKeyStorage!
workspaceId: String!
}
input ReplyCreateInput { input ReplyCreateInput {
commentId: ID! commentId: ID!
content: JSONObject! content: JSONObject!
@@ -2344,6 +2388,21 @@ enum SubscriptionVariant {
Onetime Onetime
} }
input TestWorkspaceByokConfigInput {
apiKey: String
configId: ID
endpoint: String
provider: ByokProvider!
storage: ByokKeyStorage!
workspaceId: String!
}
type TestWorkspaceByokConfigResultType {
message: String
ok: Boolean!
status: ByokKeyTestStatus!
}
enum TimeBucket { enum TimeBucket {
Day Day
Minute Minute
@@ -2501,6 +2560,19 @@ input UpdateWorkspaceInput {
"""The `Upload` scalar type represents a file upload.""" """The `Upload` scalar type represents a file upload."""
scalar Upload scalar Upload
input UpsertWorkspaceByokConfigInput {
apiKey: String
description: String
enabled: Boolean
endpoint: String
id: ID
name: String!
provider: ByokProvider!
sortOrder: SafeInt
storage: ByokKeyStorage!
workspaceId: String!
}
type UserImportFailedType { type UserImportFailedType {
email: String! email: String!
error: String! error: String!
@@ -2604,6 +2676,53 @@ type VersionRejectedDataType {
version: String! version: String!
} }
type WorkspaceByokCapabilityWarningType {
featureKind: String!
reason: String!
requiredProviders: [ByokProvider!]!
}
type WorkspaceByokKeyConfigType {
capabilities: [String!]!
configured: Boolean!
description: String
disabledReason: String
enabled: Boolean!
endpoint: String
endpointEditable: Boolean!
id: ID!
lastError: String
lastErrorAt: DateTime
lastTestError: String
lastTestedAt: DateTime
lastUsedAt: DateTime
name: String!
provider: ByokProvider!
sortOrder: SafeInt!
storage: ByokKeyStorage!
testStatus: ByokKeyTestStatus!
}
type WorkspaceByokSettingsType {
allowedProviders: [ByokProvider!]!
customEndpointSupported: Boolean!
entitled: Boolean!
entitlementRequired: [String!]!
hasAiPlan: Boolean!
keys: [WorkspaceByokKeyConfigType!]!
localEntitled: Boolean!
localStorageSupported: Boolean!
serverEntitled: Boolean!
warnings: [WorkspaceByokCapabilityWarningType!]!
workspaceId: String!
}
type WorkspaceByokUsagePointType {
date: DateTime!
featureKind: String!
totalTokens: SafeInt!
}
input WorkspaceCalendarItemInput { input WorkspaceCalendarItemInput {
colorOverride: String colorOverride: String
sortOrder: Int sortOrder: Int
@@ -2722,6 +2841,8 @@ type WorkspaceType {
"""Blobs size of workspace""" """Blobs size of workspace"""
blobsSize: Int! blobsSize: Int!
byokSettings: WorkspaceByokSettingsType!
byokUsage(from: DateTime!, to: DateTime!): [WorkspaceByokUsagePointType!]!
calendars: [WorkspaceCalendarObjectType!]! calendars: [WorkspaceCalendarObjectType!]!
"""Get comment changes of a doc""" """Get comment changes of a doc"""

View File

@@ -1,13 +0,0 @@
mutation applyDocUpdates(
$workspaceId: String!
$docId: String!
$op: String!
$updates: String!
) {
applyDocUpdates(
workspaceId: $workspaceId
docId: $docId
op: $op
updates: $updates
)
}

View File

@@ -1033,19 +1033,6 @@ export const uploadCommentAttachmentMutation = {
file: true, file: true,
}; };
export const applyDocUpdatesMutation = {
id: 'applyDocUpdatesMutation' as const,
op: 'applyDocUpdates',
query: `mutation applyDocUpdates($workspaceId: String!, $docId: String!, $op: String!, $updates: String!) {
applyDocUpdates(
workspaceId: $workspaceId
docId: $docId
op: $op
updates: $updates
)
}`,
};
export const addContextBlobMutation = { export const addContextBlobMutation = {
id: 'addContextBlobMutation' as const, id: 'addContextBlobMutation' as const,
op: 'addContextBlob', op: 'addContextBlob',
@@ -2997,6 +2984,117 @@ export const workspaceBlobQuotaQuery = {
}`, }`,
}; };
export const clearWorkspaceByokConfigsMutation = {
id: 'clearWorkspaceByokConfigsMutation' as const,
op: 'clearWorkspaceByokConfigs',
query: `mutation clearWorkspaceByokConfigs($workspaceId: String!) {
clearWorkspaceByokConfigs(workspaceId: $workspaceId)
}`,
};
export const deleteWorkspaceByokConfigMutation = {
id: 'deleteWorkspaceByokConfigMutation' as const,
op: 'deleteWorkspaceByokConfig',
query: `mutation deleteWorkspaceByokConfig($workspaceId: String!, $id: ID!) {
deleteWorkspaceByokConfig(workspaceId: $workspaceId, id: $id)
}`,
};
export const reorderWorkspaceByokConfigsMutation = {
id: 'reorderWorkspaceByokConfigsMutation' as const,
op: 'reorderWorkspaceByokConfigs',
query: `mutation reorderWorkspaceByokConfigs($input: ReorderWorkspaceByokConfigsInput!) {
reorderWorkspaceByokConfigs(input: $input) {
id
sortOrder
}
}`,
};
export const testWorkspaceByokConfigMutation = {
id: 'testWorkspaceByokConfigMutation' as const,
op: 'testWorkspaceByokConfig',
query: `mutation testWorkspaceByokConfig($input: TestWorkspaceByokConfigInput!) {
testWorkspaceByokConfig(input: $input) {
ok
status
message
}
}`,
};
export const upsertWorkspaceByokConfigMutation = {
id: 'upsertWorkspaceByokConfigMutation' as const,
op: 'upsertWorkspaceByokConfig',
query: `mutation upsertWorkspaceByokConfig($input: UpsertWorkspaceByokConfigInput!) {
upsertWorkspaceByokConfig(input: $input) {
id
}
}`,
};
export const createWorkspaceByokLocalLeaseMutation = {
id: 'createWorkspaceByokLocalLeaseMutation' as const,
op: 'createWorkspaceByokLocalLease',
query: `mutation createWorkspaceByokLocalLease($input: CreateWorkspaceByokLocalLeaseInput!) {
createWorkspaceByokLocalLease(input: $input) {
leaseId
expiresAt
}
}`,
};
export const workspaceByokSettingsQuery = {
id: 'workspaceByokSettingsQuery' as const,
op: 'workspaceByokSettings',
query: `query workspaceByokSettings($id: String!, $from: DateTime!, $to: DateTime!) {
workspace(id: $id) {
id
byokSettings {
workspaceId
entitled
serverEntitled
localEntitled
entitlementRequired
allowedProviders
localStorageSupported
customEndpointSupported
hasAiPlan
keys {
id
provider
name
description
storage
configured
enabled
endpoint
endpointEditable
sortOrder
capabilities
testStatus
disabledReason
lastTestedAt
lastTestError
lastUsedAt
lastErrorAt
lastError
}
warnings {
featureKind
reason
requiredProviders
}
}
byokUsage(from: $from, to: $to) {
date
featureKind
totalTokens
}
}
}`,
};
export const getWorkspaceConfigQuery = { export const getWorkspaceConfigQuery = {
id: 'getWorkspaceConfigQuery' as const, id: 'getWorkspaceConfigQuery' as const,
op: 'getWorkspaceConfig', op: 'getWorkspaceConfig',

View File

@@ -0,0 +1,3 @@
mutation clearWorkspaceByokConfigs($workspaceId: String!) {
clearWorkspaceByokConfigs(workspaceId: $workspaceId)
}

View File

@@ -0,0 +1,3 @@
mutation deleteWorkspaceByokConfig($workspaceId: String!, $id: ID!) {
deleteWorkspaceByokConfig(workspaceId: $workspaceId, id: $id)
}

View File

@@ -0,0 +1,8 @@
mutation reorderWorkspaceByokConfigs(
$input: ReorderWorkspaceByokConfigsInput!
) {
reorderWorkspaceByokConfigs(input: $input) {
id
sortOrder
}
}

View File

@@ -0,0 +1,7 @@
mutation testWorkspaceByokConfig($input: TestWorkspaceByokConfigInput!) {
testWorkspaceByokConfig(input: $input) {
ok
status
message
}
}

View File

@@ -0,0 +1,5 @@
mutation upsertWorkspaceByokConfig($input: UpsertWorkspaceByokConfigInput!) {
upsertWorkspaceByokConfig(input: $input) {
id
}
}

View File

@@ -0,0 +1,8 @@
mutation createWorkspaceByokLocalLease(
$input: CreateWorkspaceByokLocalLeaseInput!
) {
createWorkspaceByokLocalLease(input: $input) {
leaseId
expiresAt
}
}

View File

@@ -0,0 +1,46 @@
query workspaceByokSettings($id: String!, $from: DateTime!, $to: DateTime!) {
workspace(id: $id) {
id
byokSettings {
workspaceId
entitled
serverEntitled
localEntitled
entitlementRequired
allowedProviders
localStorageSupported
customEndpointSupported
hasAiPlan
keys {
id
provider
name
description
storage
configured
enabled
endpoint
endpointEditable
sortOrder
capabilities
testStatus
disabledReason
lastTestedAt
lastTestError
lastUsedAt
lastErrorAt
lastError
}
warnings {
featureKind
reason
requiredProviders
}
}
byokUsage(from: $from, to: $to) {
date
featureKind
totalTokens
}
}
}

View File

@@ -347,6 +347,24 @@ export interface BlobUploadedPart {
partNumber: Scalars['Int']['output']; partNumber: Scalars['Int']['output'];
} }
export enum ByokKeyStorage {
local = 'local',
server = 'server',
}
export enum ByokKeyTestStatus {
failed = 'failed',
passed = 'passed',
untested = 'untested',
}
export enum ByokProvider {
anthropic = 'anthropic',
fal = 'fal',
gemini = 'gemini',
openai = 'openai',
}
export interface CalendarAccountObjectType { export interface CalendarAccountObjectType {
__typename?: 'CalendarAccountObjectType'; __typename?: 'CalendarAccountObjectType';
calendars: Array<CalendarSubscriptionObjectType>; calendars: Array<CalendarSubscriptionObjectType>;
@@ -868,6 +886,27 @@ export interface CreateUserInput {
password?: InputMaybe<Scalars['String']['input']>; password?: InputMaybe<Scalars['String']['input']>;
} }
export interface CreateWorkspaceByokLocalLeaseInput {
providers: Array<CreateWorkspaceByokLocalLeaseProviderInput>;
workspaceId: Scalars['String']['input'];
}
export interface CreateWorkspaceByokLocalLeaseProviderInput {
apiKey: Scalars['String']['input'];
description?: InputMaybe<Scalars['String']['input']>;
enabled?: InputMaybe<Scalars['Boolean']['input']>;
endpoint?: InputMaybe<Scalars['String']['input']>;
name: Scalars['String']['input'];
provider: ByokProvider;
sortOrder?: InputMaybe<Scalars['SafeInt']['input']>;
}
export interface CreateWorkspaceByokLocalLeaseResultType {
__typename?: 'CreateWorkspaceByokLocalLeaseResultType';
expiresAt: Scalars['DateTime']['output'];
leaseId: Scalars['String']['output'];
}
export interface CredentialsRequirementType { export interface CredentialsRequirementType {
__typename?: 'CredentialsRequirementType'; __typename?: 'CredentialsRequirementType';
password: PasswordLimitsType; password: PasswordLimitsType;
@@ -1727,8 +1766,6 @@ export interface Mutation {
addWorkspaceFeature: Scalars['Boolean']['output']; addWorkspaceFeature: Scalars['Boolean']['output'];
/** Update workspace flags and features for admin */ /** Update workspace flags and features for admin */
adminUpdateWorkspace: Maybe<AdminWorkspace>; adminUpdateWorkspace: Maybe<AdminWorkspace>;
/** Apply updates to a doc using LLM and return the merged markdown. */
applyDocUpdates: Scalars['String']['output'];
approveMember: Scalars['Boolean']['output']; approveMember: Scalars['Boolean']['output'];
/** Ban an user */ /** Ban an user */
banUser: UserType; banUser: UserType;
@@ -1737,6 +1774,7 @@ export interface Mutation {
changePassword: Scalars['Boolean']['output']; changePassword: Scalars['Boolean']['output'];
/** Cleanup sessions */ /** Cleanup sessions */
cleanupCopilotSession: Array<Scalars['String']['output']>; cleanupCopilotSession: Array<Scalars['String']['output']>;
clearWorkspaceByokConfigs: Scalars['Boolean']['output'];
completeBlobUpload: Scalars['String']['output']; completeBlobUpload: Scalars['String']['output'];
createBlobUpload: BlobUploadInit; createBlobUpload: BlobUploadInit;
/** Create change password url */ /** Create change password url */
@@ -1764,6 +1802,7 @@ export interface Mutation {
createUser: UserType; createUser: UserType;
/** Create a new workspace */ /** Create a new workspace */
createWorkspace: WorkspaceType; createWorkspace: WorkspaceType;
createWorkspaceByokLocalLease: CreateWorkspaceByokLocalLeaseResultType;
deactivateLicense: Scalars['Boolean']['output']; deactivateLicense: Scalars['Boolean']['output'];
deleteAccount: DeleteAccount; deleteAccount: DeleteAccount;
deleteBlob: Scalars['Boolean']['output']; deleteBlob: Scalars['Boolean']['output'];
@@ -1774,6 +1813,7 @@ export interface Mutation {
/** Delete a user account */ /** Delete a user account */
deleteUser: DeleteAccount; deleteUser: DeleteAccount;
deleteWorkspace: Scalars['Boolean']['output']; deleteWorkspace: Scalars['Boolean']['output'];
deleteWorkspaceByokConfig: Scalars['Boolean']['output'];
/** Reenable an banned user */ /** Reenable an banned user */
enableUser: UserType; enableUser: UserType;
/** Create a chat session */ /** Create a chat session */
@@ -1815,6 +1855,7 @@ export interface Mutation {
/** Remove workspace embedding files */ /** Remove workspace embedding files */
removeWorkspaceEmbeddingFiles: Scalars['Boolean']['output']; removeWorkspaceEmbeddingFiles: Scalars['Boolean']['output'];
removeWorkspaceFeature: Scalars['Boolean']['output']; removeWorkspaceFeature: Scalars['Boolean']['output'];
reorderWorkspaceByokConfigs: Array<WorkspaceByokKeyConfigType>;
/** Request to apply the subscription in advance */ /** Request to apply the subscription in advance */
requestApplySubscription: Array<SubscriptionType>; requestApplySubscription: Array<SubscriptionType>;
/** Resolve a comment or not */ /** Resolve a comment or not */
@@ -1835,6 +1876,7 @@ export interface Mutation {
setBlob: Scalars['String']['output']; setBlob: Scalars['String']['output'];
settleTranscriptTask: Maybe<TranscriptionResultType>; settleTranscriptTask: Maybe<TranscriptionResultType>;
submitTranscriptTask: Maybe<TranscriptionResultType>; submitTranscriptTask: Maybe<TranscriptionResultType>;
testWorkspaceByokConfig: TestWorkspaceByokConfigResultType;
unlinkCalendarAccount: Scalars['Boolean']['output']; unlinkCalendarAccount: Scalars['Boolean']['output'];
/** update app configuration */ /** update app configuration */
updateAppConfig: Scalars['JSONObject']['output']; updateAppConfig: Scalars['JSONObject']['output'];
@@ -1864,6 +1906,7 @@ export interface Mutation {
uploadAvatar: UserType; uploadAvatar: UserType;
/** Upload a comment attachment and return the access url */ /** Upload a comment attachment and return the access url */
uploadCommentAttachment: Scalars['String']['output']; uploadCommentAttachment: Scalars['String']['output'];
upsertWorkspaceByokConfig: WorkspaceByokKeyConfigType;
verifyEmail: Scalars['Boolean']['output']; verifyEmail: Scalars['Boolean']['output'];
} }
@@ -1915,13 +1958,6 @@ export interface MutationAdminUpdateWorkspaceArgs {
input: AdminUpdateWorkspaceInput; input: AdminUpdateWorkspaceInput;
} }
export interface MutationApplyDocUpdatesArgs {
docId: Scalars['String']['input'];
op: Scalars['String']['input'];
updates: Scalars['String']['input'];
workspaceId: Scalars['String']['input'];
}
export interface MutationApproveMemberArgs { export interface MutationApproveMemberArgs {
userId: Scalars['String']['input']; userId: Scalars['String']['input'];
workspaceId: Scalars['String']['input']; workspaceId: Scalars['String']['input'];
@@ -1952,6 +1988,11 @@ export interface MutationCleanupCopilotSessionArgs {
options: DeleteSessionInput; options: DeleteSessionInput;
} }
export interface MutationClearWorkspaceByokConfigsArgs {
provider?: InputMaybe<ByokProvider>;
workspaceId: Scalars['String']['input'];
}
export interface MutationCompleteBlobUploadArgs { export interface MutationCompleteBlobUploadArgs {
key: Scalars['String']['input']; key: Scalars['String']['input'];
parts?: InputMaybe<Array<BlobUploadPartInput>>; parts?: InputMaybe<Array<BlobUploadPartInput>>;
@@ -2017,6 +2058,10 @@ export interface MutationCreateWorkspaceArgs {
init?: InputMaybe<Scalars['Upload']['input']>; init?: InputMaybe<Scalars['Upload']['input']>;
} }
export interface MutationCreateWorkspaceByokLocalLeaseArgs {
input: CreateWorkspaceByokLocalLeaseInput;
}
export interface MutationDeactivateLicenseArgs { export interface MutationDeactivateLicenseArgs {
workspaceId: Scalars['String']['input']; workspaceId: Scalars['String']['input'];
} }
@@ -2044,6 +2089,11 @@ export interface MutationDeleteWorkspaceArgs {
id: Scalars['String']['input']; id: Scalars['String']['input'];
} }
export interface MutationDeleteWorkspaceByokConfigArgs {
id: Scalars['ID']['input'];
workspaceId: Scalars['String']['input'];
}
export interface MutationEnableUserArgs { export interface MutationEnableUserArgs {
id: Scalars['String']['input']; id: Scalars['String']['input'];
} }
@@ -2153,6 +2203,10 @@ export interface MutationRemoveWorkspaceFeatureArgs {
workspaceId: Scalars['String']['input']; workspaceId: Scalars['String']['input'];
} }
export interface MutationReorderWorkspaceByokConfigsArgs {
input: ReorderWorkspaceByokConfigsInput;
}
export interface MutationRequestApplySubscriptionArgs { export interface MutationRequestApplySubscriptionArgs {
transactionId: Scalars['String']['input']; transactionId: Scalars['String']['input'];
} }
@@ -2241,6 +2295,10 @@ export interface MutationSubmitTranscriptTaskArgs {
workspaceId: Scalars['String']['input']; workspaceId: Scalars['String']['input'];
} }
export interface MutationTestWorkspaceByokConfigArgs {
input: TestWorkspaceByokConfigInput;
}
export interface MutationUnlinkCalendarAccountArgs { export interface MutationUnlinkCalendarAccountArgs {
accountId: Scalars['String']['input']; accountId: Scalars['String']['input'];
} }
@@ -2323,6 +2381,10 @@ export interface MutationUploadCommentAttachmentArgs {
workspaceId: Scalars['String']['input']; workspaceId: Scalars['String']['input'];
} }
export interface MutationUpsertWorkspaceByokConfigArgs {
input: UpsertWorkspaceByokConfigInput;
}
export interface MutationVerifyEmailArgs { export interface MutationVerifyEmailArgs {
token: Scalars['String']['input']; token: Scalars['String']['input'];
} }
@@ -2545,11 +2607,6 @@ export interface Query {
adminWorkspacesCount: Scalars['Int']['output']; adminWorkspacesCount: Scalars['Int']['output'];
/** get the whole app configuration */ /** get the whole app configuration */
appConfig: Scalars['JSONObject']['output']; appConfig: Scalars['JSONObject']['output'];
/**
* Apply updates to a doc using LLM and return the merged markdown.
* @deprecated use Mutation.applyDocUpdates
*/
applyDocUpdates: Scalars['String']['output'];
/** Get current user */ /** Get current user */
currentUser: Maybe<UserType>; currentUser: Maybe<UserType>;
error: ErrorDataUnion; error: ErrorDataUnion;
@@ -2608,13 +2665,6 @@ export interface QueryAdminWorkspacesCountArgs {
filter: ListWorkspaceInput; filter: ListWorkspaceInput;
} }
export interface QueryApplyDocUpdatesArgs {
docId: Scalars['String']['input'];
op: Scalars['String']['input'];
updates: Scalars['String']['input'];
workspaceId: Scalars['String']['input'];
}
export interface QueryErrorArgs { export interface QueryErrorArgs {
name: ErrorNames; name: ErrorNames;
} }
@@ -2723,6 +2773,12 @@ export interface RemoveContextFileInput {
fileId: Scalars['String']['input']; fileId: Scalars['String']['input'];
} }
export interface ReorderWorkspaceByokConfigsInput {
ids: Array<Scalars['ID']['input']>;
storage: ByokKeyStorage;
workspaceId: Scalars['String']['input'];
}
export interface ReplyCreateInput { export interface ReplyCreateInput {
commentId: Scalars['ID']['input']; commentId: Scalars['ID']['input'];
content: Scalars['JSONObject']['input']; content: Scalars['JSONObject']['input'];
@@ -3046,6 +3102,22 @@ export enum SubscriptionVariant {
Onetime = 'Onetime', Onetime = 'Onetime',
} }
export interface TestWorkspaceByokConfigInput {
apiKey?: InputMaybe<Scalars['String']['input']>;
configId?: InputMaybe<Scalars['ID']['input']>;
endpoint?: InputMaybe<Scalars['String']['input']>;
provider: ByokProvider;
storage: ByokKeyStorage;
workspaceId: Scalars['String']['input'];
}
export interface TestWorkspaceByokConfigResultType {
__typename?: 'TestWorkspaceByokConfigResultType';
message: Maybe<Scalars['String']['output']>;
ok: Scalars['Boolean']['output'];
status: ByokKeyTestStatus;
}
export enum TimeBucket { export enum TimeBucket {
Day = 'Day', Day = 'Day',
Minute = 'Minute', Minute = 'Minute',
@@ -3208,6 +3280,19 @@ export interface UpdateWorkspaceInput {
public?: InputMaybe<Scalars['Boolean']['input']>; public?: InputMaybe<Scalars['Boolean']['input']>;
} }
export interface UpsertWorkspaceByokConfigInput {
apiKey?: InputMaybe<Scalars['String']['input']>;
description?: InputMaybe<Scalars['String']['input']>;
enabled?: InputMaybe<Scalars['Boolean']['input']>;
endpoint?: InputMaybe<Scalars['String']['input']>;
id?: InputMaybe<Scalars['ID']['input']>;
name: Scalars['String']['input'];
provider: ByokProvider;
sortOrder?: InputMaybe<Scalars['SafeInt']['input']>;
storage: ByokKeyStorage;
workspaceId: Scalars['String']['input'];
}
export interface UserImportFailedType { export interface UserImportFailedType {
__typename?: 'UserImportFailedType'; __typename?: 'UserImportFailedType';
email: Scalars['String']['output']; email: Scalars['String']['output'];
@@ -3323,6 +3408,57 @@ export interface VersionRejectedDataType {
version: Scalars['String']['output']; version: Scalars['String']['output'];
} }
export interface WorkspaceByokCapabilityWarningType {
__typename?: 'WorkspaceByokCapabilityWarningType';
featureKind: Scalars['String']['output'];
reason: Scalars['String']['output'];
requiredProviders: Array<ByokProvider>;
}
export interface WorkspaceByokKeyConfigType {
__typename?: 'WorkspaceByokKeyConfigType';
capabilities: Array<Scalars['String']['output']>;
configured: Scalars['Boolean']['output'];
description: Maybe<Scalars['String']['output']>;
disabledReason: Maybe<Scalars['String']['output']>;
enabled: Scalars['Boolean']['output'];
endpoint: Maybe<Scalars['String']['output']>;
endpointEditable: Scalars['Boolean']['output'];
id: Scalars['ID']['output'];
lastError: Maybe<Scalars['String']['output']>;
lastErrorAt: Maybe<Scalars['DateTime']['output']>;
lastTestError: Maybe<Scalars['String']['output']>;
lastTestedAt: Maybe<Scalars['DateTime']['output']>;
lastUsedAt: Maybe<Scalars['DateTime']['output']>;
name: Scalars['String']['output'];
provider: ByokProvider;
sortOrder: Scalars['SafeInt']['output'];
storage: ByokKeyStorage;
testStatus: ByokKeyTestStatus;
}
export interface WorkspaceByokSettingsType {
__typename?: 'WorkspaceByokSettingsType';
allowedProviders: Array<ByokProvider>;
customEndpointSupported: Scalars['Boolean']['output'];
entitled: Scalars['Boolean']['output'];
entitlementRequired: Array<Scalars['String']['output']>;
hasAiPlan: Scalars['Boolean']['output'];
keys: Array<WorkspaceByokKeyConfigType>;
localEntitled: Scalars['Boolean']['output'];
localStorageSupported: Scalars['Boolean']['output'];
serverEntitled: Scalars['Boolean']['output'];
warnings: Array<WorkspaceByokCapabilityWarningType>;
workspaceId: Scalars['String']['output'];
}
export interface WorkspaceByokUsagePointType {
__typename?: 'WorkspaceByokUsagePointType';
date: Scalars['DateTime']['output'];
featureKind: Scalars['String']['output'];
totalTokens: Scalars['SafeInt']['output'];
}
export interface WorkspaceCalendarItemInput { export interface WorkspaceCalendarItemInput {
colorOverride?: InputMaybe<Scalars['String']['input']>; colorOverride?: InputMaybe<Scalars['String']['input']>;
sortOrder?: InputMaybe<Scalars['Int']['input']>; sortOrder?: InputMaybe<Scalars['Int']['input']>;
@@ -3453,6 +3589,8 @@ export interface WorkspaceType {
blobs: Array<ListedBlob>; blobs: Array<ListedBlob>;
/** Blobs size of workspace */ /** Blobs size of workspace */
blobsSize: Scalars['Int']['output']; blobsSize: Scalars['Int']['output'];
byokSettings: WorkspaceByokSettingsType;
byokUsage: Array<WorkspaceByokUsagePointType>;
calendars: Array<WorkspaceCalendarObjectType>; calendars: Array<WorkspaceCalendarObjectType>;
/** Get comment changes of a doc */ /** Get comment changes of a doc */
commentChanges: PaginatedCommentChangeObjectType; commentChanges: PaginatedCommentChangeObjectType;
@@ -3526,6 +3664,11 @@ export interface WorkspaceTypeBlobUploadPartUrlArgs {
uploadId: Scalars['String']['input']; uploadId: Scalars['String']['input'];
} }
export interface WorkspaceTypeByokUsageArgs {
from: Scalars['DateTime']['input'];
to: Scalars['DateTime']['input'];
}
export interface WorkspaceTypeCommentChangesArgs { export interface WorkspaceTypeCommentChangesArgs {
docId: Scalars['String']['input']; docId: Scalars['String']['input'];
pagination: PaginationInput; pagination: PaginationInput;
@@ -4642,18 +4785,6 @@ export type UploadCommentAttachmentMutation = {
uploadCommentAttachment: string; uploadCommentAttachment: string;
}; };
export type ApplyDocUpdatesMutationVariables = Exact<{
workspaceId: Scalars['String']['input'];
docId: Scalars['String']['input'];
op: Scalars['String']['input'];
updates: Scalars['String']['input'];
}>;
export type ApplyDocUpdatesMutation = {
__typename?: 'Mutation';
applyDocUpdates: string;
};
export type AddContextBlobMutationVariables = Exact<{ export type AddContextBlobMutationVariables = Exact<{
options: AddContextBlobInput; options: AddContextBlobInput;
}>; }>;
@@ -7478,6 +7609,136 @@ export type WorkspaceBlobQuotaQuery = {
}; };
}; };
export type ClearWorkspaceByokConfigsMutationVariables = Exact<{
workspaceId: Scalars['String']['input'];
}>;
export type ClearWorkspaceByokConfigsMutation = {
__typename?: 'Mutation';
clearWorkspaceByokConfigs: boolean;
};
export type DeleteWorkspaceByokConfigMutationVariables = Exact<{
workspaceId: Scalars['String']['input'];
id: Scalars['ID']['input'];
}>;
export type DeleteWorkspaceByokConfigMutation = {
__typename?: 'Mutation';
deleteWorkspaceByokConfig: boolean;
};
export type ReorderWorkspaceByokConfigsMutationVariables = Exact<{
input: ReorderWorkspaceByokConfigsInput;
}>;
export type ReorderWorkspaceByokConfigsMutation = {
__typename?: 'Mutation';
reorderWorkspaceByokConfigs: Array<{
__typename?: 'WorkspaceByokKeyConfigType';
id: string;
sortOrder: number;
}>;
};
export type TestWorkspaceByokConfigMutationVariables = Exact<{
input: TestWorkspaceByokConfigInput;
}>;
export type TestWorkspaceByokConfigMutation = {
__typename?: 'Mutation';
testWorkspaceByokConfig: {
__typename?: 'TestWorkspaceByokConfigResultType';
ok: boolean;
status: ByokKeyTestStatus;
message: string | null;
};
};
export type UpsertWorkspaceByokConfigMutationVariables = Exact<{
input: UpsertWorkspaceByokConfigInput;
}>;
export type UpsertWorkspaceByokConfigMutation = {
__typename?: 'Mutation';
upsertWorkspaceByokConfig: {
__typename?: 'WorkspaceByokKeyConfigType';
id: string;
};
};
export type CreateWorkspaceByokLocalLeaseMutationVariables = Exact<{
input: CreateWorkspaceByokLocalLeaseInput;
}>;
export type CreateWorkspaceByokLocalLeaseMutation = {
__typename?: 'Mutation';
createWorkspaceByokLocalLease: {
__typename?: 'CreateWorkspaceByokLocalLeaseResultType';
leaseId: string;
expiresAt: string;
};
};
export type WorkspaceByokSettingsQueryVariables = Exact<{
id: Scalars['String']['input'];
from: Scalars['DateTime']['input'];
to: Scalars['DateTime']['input'];
}>;
export type WorkspaceByokSettingsQuery = {
__typename?: 'Query';
workspace: {
__typename?: 'WorkspaceType';
id: string;
byokSettings: {
__typename?: 'WorkspaceByokSettingsType';
workspaceId: string;
entitled: boolean;
serverEntitled: boolean;
localEntitled: boolean;
entitlementRequired: Array<string>;
allowedProviders: Array<ByokProvider>;
localStorageSupported: boolean;
customEndpointSupported: boolean;
hasAiPlan: boolean;
keys: Array<{
__typename?: 'WorkspaceByokKeyConfigType';
id: string;
provider: ByokProvider;
name: string;
description: string | null;
storage: ByokKeyStorage;
configured: boolean;
enabled: boolean;
endpoint: string | null;
endpointEditable: boolean;
sortOrder: number;
capabilities: Array<string>;
testStatus: ByokKeyTestStatus;
disabledReason: string | null;
lastTestedAt: string | null;
lastTestError: string | null;
lastUsedAt: string | null;
lastErrorAt: string | null;
lastError: string | null;
}>;
warnings: Array<{
__typename?: 'WorkspaceByokCapabilityWarningType';
featureKind: string;
reason: string;
requiredProviders: Array<ByokProvider>;
}>;
};
byokUsage: Array<{
__typename?: 'WorkspaceByokUsagePointType';
date: string;
featureKind: string;
totalTokens: number;
}>;
};
};
export type GetWorkspaceConfigQueryVariables = Exact<{ export type GetWorkspaceConfigQueryVariables = Exact<{
id: Scalars['String']['input']; id: Scalars['String']['input'];
}>; }>;
@@ -8104,6 +8365,11 @@ export type Queries =
variables: WorkspaceBlobQuotaQueryVariables; variables: WorkspaceBlobQuotaQueryVariables;
response: WorkspaceBlobQuotaQuery; response: WorkspaceBlobQuotaQuery;
} }
| {
name: 'workspaceByokSettingsQuery';
variables: WorkspaceByokSettingsQueryVariables;
response: WorkspaceByokSettingsQuery;
}
| { | {
name: 'getWorkspaceConfigQuery'; name: 'getWorkspaceConfigQuery';
variables: GetWorkspaceConfigQueryVariables; variables: GetWorkspaceConfigQueryVariables;
@@ -8301,11 +8567,6 @@ export type Mutations =
variables: UploadCommentAttachmentMutationVariables; variables: UploadCommentAttachmentMutationVariables;
response: UploadCommentAttachmentMutation; response: UploadCommentAttachmentMutation;
} }
| {
name: 'applyDocUpdatesMutation';
variables: ApplyDocUpdatesMutationVariables;
response: ApplyDocUpdatesMutation;
}
| { | {
name: 'addContextBlobMutation'; name: 'addContextBlobMutation';
variables: AddContextBlobMutationVariables; variables: AddContextBlobMutationVariables;
@@ -8606,6 +8867,36 @@ export type Mutations =
variables: VerifyEmailMutationVariables; variables: VerifyEmailMutationVariables;
response: VerifyEmailMutation; response: VerifyEmailMutation;
} }
| {
name: 'clearWorkspaceByokConfigsMutation';
variables: ClearWorkspaceByokConfigsMutationVariables;
response: ClearWorkspaceByokConfigsMutation;
}
| {
name: 'deleteWorkspaceByokConfigMutation';
variables: DeleteWorkspaceByokConfigMutationVariables;
response: DeleteWorkspaceByokConfigMutation;
}
| {
name: 'reorderWorkspaceByokConfigsMutation';
variables: ReorderWorkspaceByokConfigsMutationVariables;
response: ReorderWorkspaceByokConfigsMutation;
}
| {
name: 'testWorkspaceByokConfigMutation';
variables: TestWorkspaceByokConfigMutationVariables;
response: TestWorkspaceByokConfigMutation;
}
| {
name: 'upsertWorkspaceByokConfigMutation';
variables: UpsertWorkspaceByokConfigMutationVariables;
response: UpsertWorkspaceByokConfigMutation;
}
| {
name: 'createWorkspaceByokLocalLeaseMutation';
variables: CreateWorkspaceByokLocalLeaseMutationVariables;
response: CreateWorkspaceByokLocalLeaseMutation;
}
| { | {
name: 'setEnableAiMutation'; name: 'setEnableAiMutation';
variables: SetEnableAiMutationVariables; variables: SetEnableAiMutationVariables;

View File

@@ -313,6 +313,18 @@
"type": "Boolean", "type": "Boolean",
"desc": "Whether to enable the copilot plugin. <br> Document: <a href=\"https://docs.affine.pro/self-host-affine/administer/ai\" target=\"_blank\">https://docs.affine.pro/self-host-affine/administer/ai</a>" "desc": "Whether to enable the copilot plugin. <br> Document: <a href=\"https://docs.affine.pro/self-host-affine/administer/ai\" target=\"_blank\">https://docs.affine.pro/self-host-affine/administer/ai</a>"
}, },
"byok.enabled": {
"type": "Boolean",
"desc": "Whether to enable workspace BYOK."
},
"byok.allowedProviders": {
"type": "Array",
"desc": "The allowlist for workspace BYOK providers."
},
"byok.allowCustomEndpoint": {
"type": "Boolean",
"desc": "Whether workspace BYOK custom endpoints are accepted."
},
"providers.profiles": { "providers.profiles": {
"type": "Array", "type": "Array",
"desc": "The profile list for copilot providers." "desc": "The profile list for copilot providers."
@@ -342,10 +354,6 @@
"type": "Object", "type": "Object",
"desc": "The config for the gemini provider in Google Vertex AI." "desc": "The config for the gemini provider in Google Vertex AI."
}, },
"providers.perplexity": {
"type": "Object",
"desc": "The config for the perplexity provider."
},
"providers.anthropic": { "providers.anthropic": {
"type": "Object", "type": "Object",
"desc": "The config for the anthropic provider." "desc": "The config for the anthropic provider."
@@ -354,10 +362,6 @@
"type": "Object", "type": "Object",
"desc": "The config for the anthropic provider in Google Vertex AI." "desc": "The config for the anthropic provider in Google Vertex AI."
}, },
"providers.morph": {
"type": "Object",
"desc": "The config for the morph provider."
},
"unsplash": { "unsplash": {
"type": "Object", "type": "Object",
"desc": "The config for the unsplash key." "desc": "The config for the unsplash key."

View File

@@ -153,7 +153,6 @@ export const KNOWN_CONFIG_GROUPS = [
'scenarios', 'scenarios',
'providers.openai', 'providers.openai',
'providers.gemini', 'providers.gemini',
'providers.perplexity',
'providers.anthropic', 'providers.anthropic',
'providers.fal', 'providers.fal',
'unsplash', 'unsplash',

View File

@@ -0,0 +1,183 @@
import path from 'node:path';
import { app, safeStorage } from 'electron';
import { PersistentJSONFileStorage } from '../shared-storage/json-file';
import type { NamespaceHandlers } from '../type';
const byokStorage = new PersistentJSONFileStorage(
path.join(app.getPath('userData'), 'workspace-byok-keys.json')
);
export function disposeWorkspaceByokStorage() {
byokStorage.dispose();
}
const allowedProviders = new Set(['openai', 'anthropic', 'gemini', 'fal']);
type WorkspaceByokKey = {
id: string;
provider: 'openai' | 'anthropic' | 'gemini' | 'fal';
name: string;
description?: string | null;
apiKey: string;
endpoint?: string | null;
sortOrder?: number | null;
enabled?: boolean | null;
};
type WorkspaceByokKeyInput = Omit<WorkspaceByokKey, 'apiKey'> & {
apiKey?: string | null;
};
function assertSupported() {
if (!safeStorage.isEncryptionAvailable()) {
throw new Error('Secure BYOK key storage is not available.');
}
}
function hasOwnField(
key: WorkspaceByokKeyInput,
field: keyof WorkspaceByokKey
) {
return Object.prototype.hasOwnProperty.call(key, field);
}
function normalizeKey(
key: WorkspaceByokKeyInput,
existing?: WorkspaceByokKey,
defaultSortOrder = 0
): WorkspaceByokKey {
if (!allowedProviders.has(key.provider)) {
throw new Error('Unsupported BYOK provider.');
}
const apiKey = key.apiKey ?? existing?.apiKey;
if (!key.id || !key.name || !apiKey) {
throw new Error('Invalid BYOK key.');
}
return {
id: key.id,
provider: key.provider,
name: key.name,
description: hasOwnField(key, 'description')
? (key.description ?? null)
: (existing?.description ?? null),
apiKey,
endpoint: hasOwnField(key, 'endpoint')
? (key.endpoint ?? null)
: (existing?.endpoint ?? null),
sortOrder: hasOwnField(key, 'sortOrder')
? (key.sortOrder ?? defaultSortOrder)
: (existing?.sortOrder ?? defaultSortOrder),
enabled: hasOwnField(key, 'enabled')
? (key.enabled ?? true)
: (existing?.enabled ?? true),
};
}
function encryptKey(key: WorkspaceByokKey) {
return safeStorage
.encryptString(JSON.stringify(normalizeKey(key)))
.toString('base64');
}
function decryptKey(value: string): WorkspaceByokKey | null {
try {
return normalizeKey(
JSON.parse(safeStorage.decryptString(Buffer.from(value, 'base64')))
);
} catch {
return null;
}
}
function sortWorkspaceKeys(keys: WorkspaceByokKey[]) {
return keys.toSorted((a, b) => (a.sortOrder ?? 0) - (b.sortOrder ?? 0));
}
function readWorkspaceKeys(workspaceId: string): WorkspaceByokKey[] {
assertSupported();
const encryptedKeys = byokStorage.get<string[]>(workspaceId) ?? [];
return sortWorkspaceKeys(
encryptedKeys.flatMap(value => {
const key = decryptKey(value);
return key ? [key] : [];
})
);
}
function writeWorkspaceKeys(workspaceId: string, keys: WorkspaceByokKey[]) {
assertSupported();
byokStorage.set(workspaceId, keys.map(encryptKey));
}
function toPublicKey({ apiKey: _, ...key }: WorkspaceByokKey) {
return {
...key,
storage: 'local',
configured: true,
endpointEditable: false,
testStatus: 'passed',
};
}
export const byokStorageHandlers = {
isSupported: async () => safeStorage.isEncryptionAvailable(),
listWorkspaceKeys: async (_e, workspaceId: string) => {
return readWorkspaceKeys(workspaceId).map(toPublicKey);
},
getWorkspaceLeaseProviders: async (_e, workspaceId: string) => {
return readWorkspaceKeys(workspaceId).filter(key => key.enabled !== false);
},
upsertWorkspaceKey: async (
_e,
workspaceId: string,
key: WorkspaceByokKeyInput
) => {
const keys = readWorkspaceKeys(workspaceId);
const index = keys.findIndex(storedKey => storedKey.id === key.id);
const nextKey = normalizeKey(
key,
index === -1 ? undefined : keys[index],
keys.length
);
if (index === -1) {
keys.push(nextKey);
} else {
keys[index] = nextKey;
}
writeWorkspaceKeys(workspaceId, keys);
return toPublicKey(nextKey);
},
deleteWorkspaceKey: async (_e, workspaceId: string, keyId: string) => {
writeWorkspaceKeys(
workspaceId,
readWorkspaceKeys(workspaceId).filter(key => key.id !== keyId)
);
return true;
},
reorderWorkspaceKeys: async (_e, workspaceId: string, ids: string[]) => {
const keys = readWorkspaceKeys(workspaceId);
const byId = new Map(keys.map(key => [key.id, key]));
const ordered = ids
.map((id, sortOrder) => {
const key = byId.get(id);
byId.delete(id);
return key ? ({ ...key, sortOrder } as WorkspaceByokKey) : null;
})
.filter((key): key is WorkspaceByokKey => !!key);
const nextKeys = sortWorkspaceKeys([
...ordered,
...Array.from(byId.values()).map((key, index) => ({
...key,
sortOrder: ordered.length + index,
})),
]);
writeWorkspaceKeys(workspaceId, nextKeys);
return nextKeys.map(toPublicKey);
},
clearWorkspaceKeys: async (_e, workspaceId: string) => {
byokStorage.del(workspaceId);
return true;
},
} satisfies NamespaceHandlers;

View File

@@ -2,6 +2,7 @@ import { I18n } from '@affine/i18n';
import { ipcMain } from 'electron'; import { ipcMain } from 'electron';
import { AFFINE_API_CHANNEL_NAME } from '../shared/type'; import { AFFINE_API_CHANNEL_NAME } from '../shared/type';
import { byokStorageHandlers } from './byok-storage/handlers';
import { clipboardHandlers } from './clipboard'; import { clipboardHandlers } from './clipboard';
import { configStorageHandlers } from './config-storage'; import { configStorageHandlers } from './config-storage';
import { findInPageHandlers } from './find-in-page'; import { findInPageHandlers } from './find-in-page';
@@ -42,6 +43,7 @@ export const allHandlers = {
recording: recordingHandlers, recording: recordingHandlers,
popup: popupHandlers, popup: popupHandlers,
i18n: i18nHandlers, i18n: i18nHandlers,
byokStorage: byokStorageHandlers,
}; };
export const registerHandlers = () => { export const registerHandlers = () => {

View File

@@ -0,0 +1,151 @@
import path from 'node:path';
import fs from 'fs-extra';
import { afterEach, beforeEach, describe, expect, test, vi } from 'vitest';
const tmpDir = path.join(__dirname, 'tmp-byok-storage');
let disposeWorkspaceByokStorage: (() => void) | undefined;
vi.mock('electron', () => ({
app: {
getPath: () => tmpDir,
on: vi.fn(),
},
safeStorage: {
isEncryptionAvailable: () => true,
encryptString: (value: string) => Buffer.from(value, 'utf-8'),
decryptString: (value: Buffer) => value.toString('utf-8'),
},
}));
beforeEach(async () => {
vi.resetModules();
disposeWorkspaceByokStorage = undefined;
await fs.remove(tmpDir);
});
afterEach(async () => {
disposeWorkspaceByokStorage?.();
vi.resetModules();
await fs.remove(tmpDir);
});
describe('byok storage handlers', () => {
test('stores encrypted local keys and keeps lease providers sorted', async () => {
const { byokStorageHandlers, disposeWorkspaceByokStorage: dispose } =
await import('@affine/electron/main/byok-storage/handlers');
disposeWorkspaceByokStorage = dispose;
const ipcEvent = undefined;
await byokStorageHandlers.upsertWorkspaceKey(ipcEvent, 'workspace-1', {
id: 'local-openai',
provider: 'openai',
name: 'OpenAI',
apiKey: 'sk-openai',
sortOrder: 1,
});
await byokStorageHandlers.upsertWorkspaceKey(ipcEvent, 'workspace-1', {
id: 'local-gemini',
provider: 'gemini',
name: 'Gemini',
apiKey: 'sk-gemini',
sortOrder: 0,
});
const list = await byokStorageHandlers.listWorkspaceKeys(
ipcEvent,
'workspace-1'
);
expect(list.map(key => key.id)).toEqual(['local-gemini', 'local-openai']);
expect(JSON.stringify(list)).not.toContain('sk-openai');
const reordered = await byokStorageHandlers.reorderWorkspaceKeys(
ipcEvent,
'workspace-1',
['local-openai', 'local-gemini']
);
expect(reordered.map(key => key.id)).toEqual([
'local-openai',
'local-gemini',
]);
const leaseProviders = await byokStorageHandlers.getWorkspaceLeaseProviders(
ipcEvent,
'workspace-1'
);
expect(leaseProviders.map(key => key.apiKey)).toEqual([
'sk-openai',
'sk-gemini',
]);
await byokStorageHandlers.clearWorkspaceKeys(ipcEvent, 'workspace-1');
await expect(
byokStorageHandlers.listWorkspaceKeys(ipcEvent, 'workspace-1')
).resolves.toEqual([]);
});
test('preserves existing local key fields during partial updates', async () => {
const { byokStorageHandlers, disposeWorkspaceByokStorage: dispose } =
await import('@affine/electron/main/byok-storage/handlers');
disposeWorkspaceByokStorage = dispose;
const ipcEvent = undefined;
await byokStorageHandlers.upsertWorkspaceKey(ipcEvent, 'workspace-1', {
id: 'local-openai',
provider: 'openai',
name: 'OpenAI',
description: 'Primary key',
apiKey: 'sk-openai',
endpoint: 'https://api.openai.example/v1',
sortOrder: 4,
enabled: false,
});
await byokStorageHandlers.upsertWorkspaceKey(ipcEvent, 'workspace-1', {
id: 'local-openai',
provider: 'openai',
name: 'OpenAI renamed',
apiKey: 'sk-openai-next',
});
const [publicKey] = await byokStorageHandlers.listWorkspaceKeys(
ipcEvent,
'workspace-1'
);
expect(publicKey).toMatchObject({
id: 'local-openai',
name: 'OpenAI renamed',
description: 'Primary key',
endpoint: 'https://api.openai.example/v1',
sortOrder: 4,
enabled: false,
});
const [leaseProvider] =
await byokStorageHandlers.getWorkspaceLeaseProviders(
ipcEvent,
'workspace-1'
);
expect(leaseProvider).toBeUndefined();
await byokStorageHandlers.upsertWorkspaceKey(ipcEvent, 'workspace-1', {
id: 'local-openai',
provider: 'openai',
name: 'OpenAI renamed again',
enabled: true,
});
const [enabledLeaseProvider] =
await byokStorageHandlers.getWorkspaceLeaseProviders(
ipcEvent,
'workspace-1'
);
expect(enabledLeaseProvider).toMatchObject({
name: 'OpenAI renamed again',
apiKey: 'sk-openai-next',
endpoint: 'https://api.openai.example/v1',
sortOrder: 4,
enabled: true,
});
});
});

View File

@@ -23,7 +23,8 @@ export default defineConfig({
test: { test: {
setupFiles: [resolve(rootDir, './scripts/setup/global.ts')], setupFiles: [resolve(rootDir, './scripts/setup/global.ts')],
include: ['./test/**/*.spec.ts'], include: ['./test/**/*.spec.ts'],
testTimeout: 30000, testTimeout: 60000,
hookTimeout: 30000,
pool: 'forks', pool: 'forks',
maxWorkers: 1, maxWorkers: 1,
coverage: { coverage: {

View File

@@ -1,57 +0,0 @@
// @generated
// This file was automatically generated and should not be edited.
@_exported import ApolloAPI
public class ApplyDocUpdatesMutation: GraphQLMutation {
public static let operationName: String = "applyDocUpdates"
public static let operationDocument: ApolloAPI.OperationDocument = .init(
definition: .init(
#"mutation applyDocUpdates($workspaceId: String!, $docId: String!, $op: String!, $updates: String!) { applyDocUpdates( workspaceId: $workspaceId docId: $docId op: $op updates: $updates ) }"#
))
public var workspaceId: String
public var docId: String
public var op: String
public var updates: String
public init(
workspaceId: String,
docId: String,
op: String,
updates: String
) {
self.workspaceId = workspaceId
self.docId = docId
self.op = op
self.updates = updates
}
public var __variables: Variables? { [
"workspaceId": workspaceId,
"docId": docId,
"op": op,
"updates": updates
] }
public struct Data: AffineGraphQL.SelectionSet {
public let __data: DataDict
public init(_dataDict: DataDict) { __data = _dataDict }
public static var __parentType: any ApolloAPI.ParentType { AffineGraphQL.Objects.Mutation }
public static var __selections: [ApolloAPI.Selection] { [
.field("applyDocUpdates", String.self, arguments: [
"workspaceId": .variable("workspaceId"),
"docId": .variable("docId"),
"op": .variable("op"),
"updates": .variable("updates")
]),
] }
public static var __fulfilledFragments: [any ApolloAPI.SelectionSet.Type] { [
ApplyDocUpdatesMutation.Data.self
] }
/// Apply updates to a doc using LLM and return the merged markdown.
public var applyDocUpdates: String { __data["applyDocUpdates"] }
}
}

View File

@@ -354,12 +354,6 @@ declare global {
files?: ContextMatchedFileChunk[]; files?: ContextMatchedFileChunk[];
docs?: ContextMatchedDocChunk[]; docs?: ContextMatchedDocChunk[];
}>; }>;
applyDocUpdates: (
workspaceId: string,
docId: string,
op: string,
updates: string
) => Promise<string>;
addContextBlob: (options: { addContextBlob: (options: {
blobId: string; blobId: string;
contextId: string; contextId: string;

View File

@@ -2,7 +2,6 @@ import track from '@affine/track';
import { WithDisposable } from '@blocksuite/affine/global/lit'; import { WithDisposable } from '@blocksuite/affine/global/lit';
import { unsafeCSSVar, unsafeCSSVarV2 } from '@blocksuite/affine/shared/theme'; import { unsafeCSSVar, unsafeCSSVarV2 } from '@blocksuite/affine/shared/theme';
import { type EditorHost, ShadowlessElement } from '@blocksuite/affine/std'; import { type EditorHost, ShadowlessElement } from '@blocksuite/affine/std';
import { LoadingIcon } from '@blocksuite/affine-components/icons';
import type { NotificationService } from '@blocksuite/affine-shared/services'; import type { NotificationService } from '@blocksuite/affine-shared/services';
import { import {
CloseIcon, CloseIcon,
@@ -17,8 +16,6 @@ import { css, html, nothing } from 'lit';
import { property, state } from 'lit/decorators.js'; import { property, state } from 'lit/decorators.js';
import { repeat } from 'lit/directives/repeat.js'; import { repeat } from 'lit/directives/repeat.js';
import { AIProvider } from '../../provider';
import { BlockDiffProvider } from '../../services/block-diff';
import { diffMarkdown } from '../../utils/apply-model/markdown-diff'; import { diffMarkdown } from '../../utils/apply-model/markdown-diff';
import { copyText } from '../../utils/editor-actions'; import { copyText } from '../../utils/editor-actions';
import { AI_CHAT_AUTO_SCROLL_PAUSE_EVENT } from '../ai-chat-messages/auto-scroll'; import { AI_CHAT_AUTO_SCROLL_PAUSE_EVENT } from '../ai-chat-messages/auto-scroll';
@@ -218,61 +215,21 @@ export class DocEditTool extends WithDisposable(ShadowlessElement) {
@state() @state()
accessor isCollapsed = false; accessor isCollapsed = false;
@state()
accessor applyingMap: Record<string, boolean> = {};
@state()
accessor acceptingMap: Record<string, boolean> = {};
get blockDiffService() {
return this.host?.std.getOptional(BlockDiffProvider);
}
get isBusy() { get isBusy() {
return undefined; return undefined;
} }
isBusyForOp(op: string) { private _handleApply(op: string) {
return this.applyingMap[op] || this.acceptingMap[op]; if (!this.host || this.data.type !== 'tool-result') {
}
private async _handleApply(op: string, updates: string) {
if (
!this.host ||
this.data.type !== 'tool-result' ||
this.isBusyForOp(op)
) {
return; return;
} }
this.applyingMap = { ...this.applyingMap, [op]: true }; track.applyModel.chat.$.apply({
try { instruction: this.data.args.instructions,
const markdown = await AIProvider.context?.applyDocUpdates( operation: op,
this.host.std.workspace.id, });
this.data.args.doc_id,
op,
updates
);
if (!markdown) {
return;
}
track.applyModel.chat.$.apply({
instruction: this.data.args.instructions,
operation: op,
});
await this.blockDiffService?.apply(this.host.store, markdown);
} catch (error) {
this.notificationService.notify({
title: 'Failed to apply updates',
message: error instanceof Error ? error.message : 'Unknown error',
accent: 'error',
onClose: function (): void {},
});
} finally {
this.applyingMap = { ...this.applyingMap, [op]: false };
}
} }
private async _handleReject(op: string) { private _handleReject(op: string) {
if (!this.host || this.data.type !== 'tool-result') { if (!this.host || this.data.type !== 'tool-result') {
return; return;
} }
@@ -281,45 +238,16 @@ export class DocEditTool extends WithDisposable(ShadowlessElement) {
instruction: this.data.args.instructions, instruction: this.data.args.instructions,
operation: op, operation: op,
}); });
this.blockDiffService?.setChangedMarkdown(null);
this.blockDiffService?.rejectAll();
} }
private async _handleAccept(op: string, updates: string) { private _handleAccept(op: string) {
if ( if (!this.host || this.data.type !== 'tool-result') {
!this.host ||
this.data.type !== 'tool-result' ||
this.isBusyForOp(op)
) {
return; return;
} }
this.acceptingMap = { ...this.acceptingMap, [op]: true }; track.applyModel.chat.$.accept({
try { instruction: this.data.args.instructions,
const changedMarkdown = await AIProvider.context?.applyDocUpdates( operation: op,
this.host.std.workspace.id, });
this.data.args.doc_id,
op,
updates
);
if (!changedMarkdown) {
return;
}
track.applyModel.chat.$.accept({
instruction: this.data.args.instructions,
operation: op,
});
await this.blockDiffService?.apply(this.host.store, changedMarkdown);
await this.blockDiffService?.acceptAll(this.host.store);
} catch (error) {
this.notificationService.notify({
title: 'Failed to apply updates',
message: error instanceof Error ? error.message : 'Unknown error',
accent: 'error',
onClose: function (): void {},
});
} finally {
this.acceptingMap = { ...this.acceptingMap, [op]: false };
}
} }
private async _toggleCollapse() { private async _toggleCollapse() {
@@ -421,7 +349,7 @@ export class DocEditTool extends WithDisposable(ShadowlessElement) {
return repeat( return repeat(
result.result, result.result,
change => change.op, change => change.op,
({ op, updates, originalContent, changedContent }) => { ({ op, originalContent, changedContent }) => {
const diffs = diffMarkdown(originalContent, changedContent); const diffs = diffMarkdown(originalContent, changedContent);
return html` return html`
<div class="doc-edit-tool-result-wrapper"> <div class="doc-edit-tool-result-wrapper">
@@ -449,14 +377,7 @@ export class DocEditTool extends WithDisposable(ShadowlessElement) {
${CopyIcon()} ${CopyIcon()}
<affine-tooltip>Copy</affine-tooltip> <affine-tooltip>Copy</affine-tooltip>
</button> </button>
<button <button @click=${() => this._handleApply(op)}>Apply</button>
@click=${() => this._handleApply(op, updates)}
?disabled=${this.isBusyForOp(op)}
>
${this.applyingMap[op]
? html`${LoadingIcon()} Applying`
: 'Apply'}
</button>
</div> </div>
</div> </div>
<div class="doc-edit-tool-result-card-content"> <div class="doc-edit-tool-result-card-content">
@@ -473,18 +394,12 @@ export class DocEditTool extends WithDisposable(ShadowlessElement) {
</button> </button>
<button <button
class="doc-edit-tool-result-accept" class="doc-edit-tool-result-accept"
@click=${() => this._handleAccept(op, updates)} @click=${() => this._handleAccept(op)}
?disabled=${this.isBusyForOp(op)}
style="${this.isBusyForOp(op)
? 'pointer-events: none; opacity: 0.6;'
: ''}"
> >
${this.acceptingMap[op] ${DoneIcon({
? html`${LoadingIcon()}` style: `color: ${unsafeCSSVarV2('icon/activated')}`,
: DoneIcon({ })}
style: `color: ${unsafeCSSVarV2('icon/activated')}`, Accept
})}
${this.acceptingMap[op] ? 'Accepting...' : 'Accept'}
</button> </button>
</div> </div>
</div> </div>

View File

@@ -6,7 +6,6 @@ import {
addContextCategoryMutation, addContextCategoryMutation,
addContextDocMutation, addContextDocMutation,
addContextFileMutation, addContextFileMutation,
applyDocUpdatesMutation,
cleanupCopilotSessionMutation, cleanupCopilotSessionMutation,
createCopilotContextMutation, createCopilotContextMutation,
createCopilotMessageMutation, createCopilotMessageMutation,
@@ -473,6 +472,7 @@ export class CopilotClient {
actionVersion, actionVersion,
runId, runId,
retry, retry,
byokLeaseId,
}: { }: {
sessionId: string; sessionId: string;
messageId?: string; messageId?: string;
@@ -483,6 +483,7 @@ export class CopilotClient {
actionVersion?: string; actionVersion?: string;
runId?: string; runId?: string;
retry?: boolean; retry?: boolean;
byokLeaseId?: string;
}, },
endpoint = Endpoint.StreamObject endpoint = Endpoint.StreamObject
) { ) {
@@ -499,6 +500,7 @@ export class CopilotClient {
actionVersion, actionVersion,
runId, runId,
retry, retry,
byokLeaseId,
}); });
if (queryString) { if (queryString) {
url += `?${queryString}`; url += `?${queryString}`;
@@ -511,12 +513,14 @@ export class CopilotClient {
sessionId: string, sessionId: string,
messageId?: string, messageId?: string,
seed?: string, seed?: string,
endpoint = Endpoint.Images endpoint = Endpoint.Images,
byokLeaseId?: string
) { ) {
let url = `/api/copilot/chat/${sessionId}/${endpoint}`; let url = `/api/copilot/chat/${sessionId}/${endpoint}`;
const queryString = this.paramsToQueryString({ const queryString = this.paramsToQueryString({
messageId, messageId,
seed, seed,
byokLeaseId,
}); });
if (queryString) { if (queryString) {
url += `?${queryString}`; url += `?${queryString}`;
@@ -549,23 +553,6 @@ export class CopilotClient {
}).then(res => res.queryWorkspaceEmbeddingStatus); }).then(res => res.queryWorkspaceEmbeddingStatus);
} }
applyDocUpdates(
workspaceId: string,
docId: string,
op: string,
updates: string
) {
return this.gql({
query: applyDocUpdatesMutation,
variables: {
workspaceId,
docId,
op,
updates,
},
}).then(res => res.applyDocUpdates);
}
addContextBlob(options: OptionsField<typeof addContextBlobMutation>) { addContextBlob(options: OptionsField<typeof addContextBlobMutation>) {
return this.gql({ return this.gql({
query: addContextBlobMutation, query: addContextBlobMutation,

View File

@@ -0,0 +1,284 @@
/**
* @vitest-environment happy-dom
*/
import { UserFriendlyError } from '@affine/error';
import { beforeEach, describe, expect, test, vi } from 'vitest';
import { type CopilotClient, Endpoint } from './copilot-client';
import { textToText, toImage } from './request';
const electronApis = vi.hoisted(() => ({
byokStorage: undefined as
| {
isSupported: () => Promise<boolean>;
getWorkspaceLeaseProviders: (workspaceId: string) => Promise<
Array<{
provider: string;
name: string;
apiKey: string;
description?: string | null;
endpoint?: string | null;
sortOrder?: number | null;
enabled?: boolean | null;
}>
>;
}
| undefined,
}));
const createWorkspaceByokLocalLeaseMutation = vi.hoisted(() =>
Symbol('createWorkspaceByokLocalLeaseMutation')
);
vi.mock('@affine/electron-api', () => ({
apis: electronApis,
}));
vi.mock('@affine/graphql', () => ({
ByokProvider: {
openai: 'openai',
anthropic: 'anthropic',
gemini: 'gemini',
fal: 'fal',
},
createWorkspaceByokLocalLeaseMutation,
}));
function createClient(
overrides: Partial<
Pick<
CopilotClient,
'gql' | 'createMessage' | 'chatTextStream' | 'imagesStream'
>
> = {}
) {
return {
gql: vi.fn().mockResolvedValue({
createWorkspaceByokLocalLease: { leaseId: 'lease-1' },
}),
createMessage: vi.fn().mockResolvedValue('message-1'),
chatTextStream: vi.fn(),
imagesStream: vi.fn(),
...overrides,
} as unknown as CopilotClient;
}
async function drain(stream: AsyncIterable<unknown>) {
for await (const chunk of stream) {
void chunk;
}
}
describe('AI request BYOK local lease handling', () => {
beforeEach(() => {
vi.stubGlobal('BUILD_CONFIG', { isElectron: true });
electronApis.byokStorage = {
isSupported: vi.fn().mockResolvedValue(true),
getWorkspaceLeaseProviders: vi.fn().mockResolvedValue([
{
provider: 'openai',
name: 'OpenAI',
apiKey: 'sk-local',
},
]),
};
});
test('fails closed when local BYOK providers exist but lease creation fails', async () => {
const client = createClient({
gql: vi.fn().mockRejectedValue(new Error('mutation failed')),
});
const result = textToText({
client,
sessionId: 'session-1',
workspaceId: 'workspace-1',
content: 'hello',
}) as Promise<string>;
await expect(result).rejects.toThrow('mutation failed');
await expect(result).rejects.toBeInstanceOf(UserFriendlyError);
expect(client.chatTextStream).not.toHaveBeenCalled();
});
test('wraps local BYOK storage support failures as user friendly errors', async () => {
electronApis.byokStorage = {
isSupported: vi.fn().mockRejectedValue(new Error('support check failed')),
getWorkspaceLeaseProviders: vi.fn(),
};
const client = createClient();
const result = textToText({
client,
sessionId: 'session-1',
workspaceId: 'workspace-1',
content: 'hello',
}) as Promise<string>;
await expect(result).rejects.toThrow('support check failed');
await expect(result).rejects.toBeInstanceOf(UserFriendlyError);
expect(client.chatTextStream).not.toHaveBeenCalled();
});
test('wraps local BYOK provider loading failures as user friendly errors', async () => {
electronApis.byokStorage = {
isSupported: vi.fn().mockResolvedValue(true),
getWorkspaceLeaseProviders: vi
.fn()
.mockRejectedValue(new Error('provider load failed')),
};
const client = createClient();
const result = textToText({
client,
sessionId: 'session-1',
workspaceId: 'workspace-1',
content: 'hello',
}) as Promise<string>;
await expect(result).rejects.toThrow('provider load failed');
await expect(result).rejects.toBeInstanceOf(UserFriendlyError);
expect(client.chatTextStream).not.toHaveBeenCalled();
});
test('does not create local BYOK lease after cancellation', async () => {
const controller = new AbortController();
const client = createClient({
createMessage: vi.fn().mockImplementation(async () => {
controller.abort();
return 'message-1';
}),
});
await expect(
textToText({
client,
sessionId: 'session-1',
workspaceId: 'workspace-1',
content: 'hello',
signal: controller.signal,
}) as Promise<string>
).resolves.toBe('');
expect(client.gql).not.toHaveBeenCalled();
expect(client.chatTextStream).not.toHaveBeenCalled();
});
test('does not create stream local BYOK lease after cancellation', async () => {
const controller = new AbortController();
const client = createClient({
createMessage: vi.fn().mockImplementation(async () => {
controller.abort();
return 'message-1';
}),
});
await drain(
textToText({
client,
sessionId: 'session-1',
workspaceId: 'workspace-1',
content: 'hello',
stream: true,
signal: controller.signal,
}) as AsyncIterable<string>
);
expect(client.gql).not.toHaveBeenCalled();
expect(client.chatTextStream).not.toHaveBeenCalled();
});
test('does not create text stream when cancelled while creating local BYOK lease', async () => {
const controller = new AbortController();
const client = createClient({
gql: vi.fn().mockImplementation(async () => {
controller.abort();
return { createWorkspaceByokLocalLease: { leaseId: 'lease-1' } };
}),
});
await drain(
textToText({
client,
sessionId: 'session-1',
workspaceId: 'workspace-1',
content: 'hello',
stream: true,
signal: controller.signal,
}) as AsyncIterable<string>
);
expect(client.gql).toHaveBeenCalled();
expect(client.chatTextStream).not.toHaveBeenCalled();
});
test('does not create text request when cancelled while creating local BYOK lease', async () => {
const controller = new AbortController();
const client = createClient({
gql: vi.fn().mockImplementation(async () => {
controller.abort();
return { createWorkspaceByokLocalLease: { leaseId: 'lease-1' } };
}),
});
await expect(
textToText({
client,
sessionId: 'session-1',
workspaceId: 'workspace-1',
content: 'hello',
signal: controller.signal,
}) as Promise<string>
).resolves.toBe('');
expect(client.gql).toHaveBeenCalled();
expect(client.chatTextStream).not.toHaveBeenCalled();
});
test('does not create image local BYOK lease after cancellation', async () => {
const controller = new AbortController();
const client = createClient({
createMessage: vi.fn().mockImplementation(async () => {
controller.abort();
return 'message-1';
}),
});
await drain(
toImage({
client,
sessionId: 'session-1',
workspaceId: 'workspace-1',
content: 'image',
endpoint: Endpoint.Images,
signal: controller.signal,
}) as AsyncIterable<string>
);
expect(client.gql).not.toHaveBeenCalled();
expect(client.imagesStream).not.toHaveBeenCalled();
});
test('does not create image stream when cancelled while creating local BYOK lease', async () => {
const controller = new AbortController();
const client = createClient({
gql: vi.fn().mockImplementation(async () => {
controller.abort();
return { createWorkspaceByokLocalLease: { leaseId: 'lease-1' } };
}),
});
await drain(
toImage({
client,
sessionId: 'session-1',
workspaceId: 'workspace-1',
content: 'image',
endpoint: Endpoint.Images,
signal: controller.signal,
}) as AsyncIterable<string>
);
expect(client.gql).toHaveBeenCalled();
expect(client.imagesStream).not.toHaveBeenCalled();
});
});

View File

@@ -1,4 +1,10 @@
import type { AIToolsConfig } from '@affine/core/modules/ai-button'; import type { AIToolsConfig } from '@affine/core/modules/ai-button';
import { apis, type ClientHandler } from '@affine/electron-api';
import { UserFriendlyError } from '@affine/error';
import {
ByokProvider,
createWorkspaceByokLocalLeaseMutation,
} from '@affine/graphql';
import { partition } from 'lodash-es'; import { partition } from 'lodash-es';
import { AIProvider } from './ai-provider'; import { AIProvider } from './ai-provider';
@@ -7,9 +13,99 @@ import { toTextStream } from './event-source';
const TIMEOUT = 50000; const TIMEOUT = 50000;
function isElectronBuild() {
return typeof BUILD_CONFIG !== 'undefined' && BUILD_CONFIG.isElectron;
}
function byokStorageApi(): ClientHandler['byokStorage'] | undefined {
return isElectronBuild() ? apis?.byokStorage : undefined;
}
function toGraphqlByokProvider(provider: string): ByokProvider | null {
switch (provider) {
case ByokProvider.openai:
return ByokProvider.openai;
case ByokProvider.anthropic:
return ByokProvider.anthropic;
case ByokProvider.gemini:
return ByokProvider.gemini;
case ByokProvider.fal:
return ByokProvider.fal;
default:
return null;
}
}
function errorMetadata(error: unknown) {
if (!error || typeof error !== 'object') {
return { kind: typeof error };
}
const record = error as Record<string, unknown>;
return {
name: typeof record.name === 'string' ? record.name : undefined,
code: typeof record.code === 'string' ? record.code : undefined,
status:
typeof record.status === 'number' || typeof record.status === 'string'
? record.status
: undefined,
type: typeof record.type === 'string' ? record.type : undefined,
};
}
async function createWorkspaceByokLocalLease(
client: CopilotClient,
workspaceId?: string
) {
const storage = byokStorageApi();
if (!workspaceId || !storage) {
return undefined;
}
try {
if (!(await storage.isSupported())) return undefined;
const providers = await storage.getWorkspaceLeaseProviders(workspaceId);
if (!providers.length) return undefined;
const leaseProviders = providers.flatMap(provider => {
const gqlProvider = toGraphqlByokProvider(provider.provider);
return gqlProvider
? [
{
provider: gqlProvider,
name: provider.name,
description: provider.description ?? null,
apiKey: provider.apiKey,
endpoint: provider.endpoint ?? null,
sortOrder: provider.sortOrder ?? 0,
enabled: provider.enabled ?? true,
},
]
: [];
});
if (!leaseProviders.length) return undefined;
const result = await client.gql({
query: createWorkspaceByokLocalLeaseMutation,
variables: {
input: {
workspaceId,
providers: leaseProviders,
},
},
});
return result.createWorkspaceByokLocalLease.leaseId;
} catch (error) {
console.warn(
'Failed to create workspace BYOK local lease',
errorMetadata(error)
);
throw UserFriendlyError.fromAny(error);
}
}
export type TextToTextOptions = { export type TextToTextOptions = {
client: CopilotClient; client: CopilotClient;
sessionId: string; sessionId: string;
workspaceId?: string;
content?: string; content?: string;
attachments?: (string | Blob | File)[]; attachments?: (string | Blob | File)[];
params?: Record<string, any>; params?: Record<string, any>;
@@ -114,6 +210,7 @@ async function createMessage({
export function textToText({ export function textToText({
client, client,
sessionId, sessionId,
workspaceId,
content, content,
attachments, attachments,
params, params,
@@ -145,6 +242,16 @@ export function textToText({
signal, signal,
}); });
} }
if (signal?.aborted) {
return;
}
const byokLeaseId = await createWorkspaceByokLocalLease(
client,
workspaceId
);
if (signal?.aborted) {
return;
}
const eventSource = client.chatTextStream( const eventSource = client.chatTextStream(
{ {
sessionId, sessionId,
@@ -156,6 +263,7 @@ export function textToText({
actionVersion, actionVersion,
runId, runId,
retry, retry,
byokLeaseId,
}, },
endpoint endpoint
); );
@@ -203,6 +311,16 @@ export function textToText({
signal, signal,
}); });
} }
if (signal?.aborted) {
return '';
}
const byokLeaseId = await createWorkspaceByokLocalLease(
client,
workspaceId
);
if (signal?.aborted) {
return '';
}
const eventSource = client.chatTextStream( const eventSource = client.chatTextStream(
{ {
sessionId, sessionId,
@@ -214,6 +332,7 @@ export function textToText({
actionVersion, actionVersion,
runId, runId,
retry, retry,
byokLeaseId,
}, },
endpoint endpoint
); );
@@ -258,6 +377,7 @@ export function textToText({
export function toImage({ export function toImage({
content, content,
sessionId, sessionId,
workspaceId,
attachments, attachments,
params, params,
seed, seed,
@@ -284,6 +404,16 @@ export function toImage({
signal, signal,
}); });
} }
if (signal?.aborted) {
return;
}
const byokLeaseId = await createWorkspaceByokLocalLease(
client,
workspaceId
);
if (signal?.aborted) {
return;
}
const eventSource = const eventSource =
endpoint === Endpoint.Action endpoint === Endpoint.Action
? client.chatTextStream( ? client.chatTextStream(
@@ -294,10 +424,17 @@ export function toImage({
actionVersion, actionVersion,
runId, runId,
retry, retry,
byokLeaseId,
}, },
Endpoint.Action Endpoint.Action
) )
: client.imagesStream(sessionId, messageId, seed, endpoint); : client.imagesStream(
sessionId,
messageId,
seed,
endpoint,
byokLeaseId
);
AIProvider.LAST_ACTION_SESSIONID = sessionId; AIProvider.LAST_ACTION_SESSIONID = sessionId;
for await (const event of toTextStream(eventSource, { for await (const event of toTextStream(eventSource, {

View File

@@ -722,14 +722,6 @@ Could you make a new website based on these notes and send back just the html fi
threshold threshold
); );
}, },
applyDocUpdates: async (
workspaceId: string,
docId: string,
op: string,
updates: string
) => {
return client.applyDocUpdates(workspaceId, docId, op, updates);
},
addContextBlob: async (options: { blobId: string; contextId: string }) => { addContextBlob: async (options: { blobId: string; contextId: string }) => {
return client.addContextBlob({ return client.addContextBlob({
contextId: options.contextId, contextId: options.contextId,

View File

@@ -0,0 +1,346 @@
import { Button, Modal, notify } from '@affine/component';
import {
ByokKeyStorage,
ByokProvider,
testWorkspaceByokConfigMutation as testByokMutation,
upsertWorkspaceByokConfigMutation as upsertByokMutation,
} from '@affine/graphql';
import { useI18n } from '@affine/i18n';
import { useCallback, useEffect, useState } from 'react';
import { logByokError } from './errors';
import * as styles from './index.css';
import { readLocalKeys, upsertLocalKey } from './local-storage';
import { byokT, providerLabels, storageLabel } from './metadata';
import type {
ByokKey,
ByokSettings,
ByokStorage,
ByokTestResult,
GqlFn,
} from './types';
export const AddKeyModal = ({
workspaceId,
settings,
editingKey,
open,
onOpenChange,
onSaved,
localKeys,
setLocalKeys,
localStorageSupported,
canAddServerKey,
canAddLocalKey,
gql,
}: {
workspaceId: string;
settings: ByokSettings;
editingKey: ByokKey | null;
open: boolean;
onOpenChange: (open: boolean) => void;
onSaved: () => Promise<void>;
localKeys: ByokKey[];
setLocalKeys: (keys: ByokKey[]) => void;
localStorageSupported: boolean;
canAddServerKey: boolean;
canAddLocalKey: boolean;
gql?: GqlFn;
}) => {
const t = useI18n();
const [provider, setProvider] = useState<ByokProvider>(ByokProvider.openai);
const [name, setName] = useState('');
const [description, setDescription] = useState('');
const [storage, setStorage] = useState<ByokStorage>(ByokKeyStorage.server);
const [apiKey, setApiKey] = useState('');
const [endpoint, setEndpoint] = useState('');
const [testResult, setTestResult] = useState<ByokTestResult | null>(null);
const [testing, setTesting] = useState(false);
const canTestStoredConfig =
storage === ByokKeyStorage.server &&
editingKey?.storage === ByokKeyStorage.server &&
editingKey.provider === provider;
const canTest = !!apiKey || canTestStoredConfig;
useEffect(() => {
if (!open) {
return;
}
setProvider(editingKey?.provider ?? ByokProvider.openai);
setName(editingKey?.name ?? '');
setDescription(editingKey?.description ?? '');
setStorage(
editingKey?.storage ??
(canAddServerKey ? ByokKeyStorage.server : ByokKeyStorage.local)
);
setApiKey('');
setEndpoint(editingKey?.endpoint ?? '');
setTestResult(null);
}, [canAddServerKey, editingKey, open]);
const testKey = useCallback(async () => {
if (!gql) {
return;
}
setTesting(true);
try {
const result = await gql({
query: testByokMutation,
variables: {
input: {
workspaceId,
provider,
storage,
apiKey: apiKey || null,
endpoint: endpoint || null,
configId: canTestStoredConfig ? editingKey.id : null,
},
},
});
const nextResult = result.testWorkspaceByokConfig as
| ByokTestResult
| undefined;
setTestResult(nextResult ?? null);
if (nextResult && !nextResult.ok) {
notify.error({
title: byokT(t, 'notify.test-failed.title'),
message: nextResult.message,
});
}
} finally {
setTesting(false);
}
}, [
apiKey,
canTestStoredConfig,
editingKey,
endpoint,
gql,
provider,
storage,
t,
workspaceId,
]);
const save = useCallback(async () => {
if (!testResult?.ok || !gql) {
return;
}
if (storage === ByokKeyStorage.local) {
const saved = await upsertLocalKey(workspaceId, {
id:
editingKey?.storage === ByokKeyStorage.local
? editingKey.id
: crypto.randomUUID(),
provider,
name,
description,
apiKey,
endpoint: endpoint || null,
sortOrder:
editingKey?.storage === ByokKeyStorage.local
? editingKey.sortOrder
: localKeys.length,
enabled: true,
});
if (!saved) {
notify.error({
title: byokT(t, 'notify.local-save-failed.title'),
message: byokT(t, 'notify.local-save-failed.message'),
});
return;
}
setLocalKeys(await readLocalKeys(workspaceId));
} else {
await gql({
query: upsertByokMutation,
variables: {
input: {
workspaceId,
id:
editingKey?.storage === ByokKeyStorage.server
? editingKey.id
: null,
provider,
name,
description,
storage,
apiKey: apiKey || null,
endpoint: endpoint || null,
enabled: true,
},
},
});
await onSaved();
}
onOpenChange(false);
setApiKey('');
setTestResult(null);
}, [
apiKey,
description,
editingKey,
endpoint,
gql,
localKeys,
name,
onOpenChange,
onSaved,
provider,
setLocalKeys,
storage,
t,
testResult?.ok,
workspaceId,
]);
return (
<Modal
width={520}
open={open}
onOpenChange={onOpenChange}
title={
editingKey ? byokT(t, 'modal.edit-title') : byokT(t, 'modal.add-title')
}
description={byokT(t, 'modal.description')}
>
<div className={styles.form}>
<label className={styles.field}>
<span className={styles.label}>{byokT(t, 'field.provider')}</span>
<select
className={styles.input}
value={provider}
onChange={event => {
setProvider(event.target.value as ByokProvider);
setTestResult(null);
}}
>
{settings.allowedProviders.map(provider => (
<option key={provider} value={provider}>
{providerLabels[provider]}
</option>
))}
</select>
</label>
<label className={styles.field}>
<span className={styles.label}>{byokT(t, 'field.key-name')}</span>
<input
className={styles.input}
value={name}
onChange={event => setName(event.target.value)}
placeholder={byokT(t, 'placeholder.key-name')}
/>
</label>
<label className={styles.field}>
<span className={styles.label}>{byokT(t, 'field.description')}</span>
<input
className={styles.input}
value={description}
onChange={event => setDescription(event.target.value)}
placeholder={byokT(t, 'placeholder.description')}
/>
</label>
<label className={styles.field}>
<span className={styles.label}>{byokT(t, 'field.storage')}</span>
<select
className={styles.input}
value={storage}
disabled={!!editingKey}
onChange={event => {
setStorage(event.target.value as ByokStorage);
setTestResult(null);
}}
>
<option value={ByokKeyStorage.server} disabled={!canAddServerKey}>
{storageLabel(t, ByokKeyStorage.server)}
</option>
<option
value={ByokKeyStorage.local}
disabled={!localStorageSupported || !canAddLocalKey}
>
{canAddLocalKey
? byokT(t, 'storage.local-this-device')
: byokT(t, 'storage.local-desktop-only')}
</option>
</select>
</label>
<label className={styles.field}>
<span className={styles.label}>{byokT(t, 'field.api-key')}</span>
<input
className={styles.input}
value={apiKey}
onChange={event => {
setApiKey(event.target.value);
setTestResult(null);
}}
type="password"
/>
</label>
{settings.customEndpointSupported ? (
<label className={styles.field}>
<span className={styles.label}>{byokT(t, 'field.endpoint')}</span>
<input
className={styles.input}
value={endpoint}
onChange={event => {
setEndpoint(event.target.value);
setTestResult(null);
}}
placeholder="https://api.example.com/v1"
/>
</label>
) : null}
<div className={styles.modalActions}>
<span
className={`${styles.testStatus} ${
testResult?.ok
? styles.success
: testResult && !testResult.ok
? styles.error
: ''
}`}
>
{testResult?.ok
? byokT(t, 'status.key-verified')
: testResult
? byokT(t, 'status.key-test-failed')
: ''}
</span>
<Button
variant="secondary"
disabled={!canTest || testing}
onClick={() => {
testKey().catch(error => {
logByokError('Failed to test BYOK key', error);
notify.error({
title: byokT(t, 'notify.test-failed.title'),
message: byokT(t, 'notify.operation-failed.message'),
});
});
}}
>
{byokT(t, 'action.test-key')}
</Button>
<Button variant="secondary" onClick={() => onOpenChange(false)}>
{byokT(t, 'action.cancel')}
</Button>
<Button
variant="primary"
disabled={!testResult?.ok || !name}
onClick={() => {
save().catch(error => {
logByokError('Failed to save BYOK key', error);
notify.error({
title: byokT(t, 'notify.save-failed.title'),
message: byokT(t, 'notify.operation-failed.message'),
});
});
}}
>
{byokT(t, 'action.save-key')}
</Button>
</div>
</div>
</Modal>
);
};

View File

@@ -0,0 +1,98 @@
import { useI18n } from '@affine/i18n';
import {
ChatWithAiIcon,
ImageIcon,
PenIcon,
TocIcon,
TranscriptWithAiIcon,
} from '@blocksuite/icons/rc';
import type { ReactNode } from 'react';
import * as styles from './index.css';
import { byokT, capabilityRows, warningDescription } from './metadata';
import type { ByokKey, ByokSettings } from './types';
function coverageIcon(
icon: (typeof capabilityRows)[number]['icon']
): ReactNode {
switch (icon) {
case 'chat':
return <ChatWithAiIcon className={styles.capabilityIconSvg} />;
case 'action':
return <PenIcon className={styles.capabilityIconSvg} />;
case 'image':
return <ImageIcon className={styles.capabilityIconSvg} />;
case 'transcript':
return <TranscriptWithAiIcon className={styles.capabilityIconSvg} />;
case 'indexing':
return <TocIcon className={styles.capabilityIconSvg} />;
}
}
function isRowCovered(row: (typeof capabilityRows)[number], keys: ByokKey[]) {
if (!row.coverageCapabilities.length) {
return false;
}
return keys.some(key => {
if (!key.configured || !key.enabled) {
return false;
}
return (
(!('storage' in row) || row.storage === key.storage) &&
row.coverageCapabilities.every(capability =>
key.capabilities.includes(capability)
)
);
});
}
export const CoveragePanel = ({
keys,
settings,
}: {
keys: ByokKey[];
settings: ByokSettings;
}) => {
const t = useI18n();
return (
<div className={styles.panel}>
<div className={styles.panelHeader}>
<div className={styles.title}>{byokT(t, 'coverage.title')}</div>
</div>
<div className={styles.rows}>
{capabilityRows.map(row => {
const warning = settings.warnings.find(
w => w.featureKind === row.featureKind
);
const covered = isRowCovered(row, keys);
return (
<div
className={`${styles.row} ${styles.capabilityRow} ${
covered ? '' : styles.capabilityRowInactive
}`}
data-covered={covered}
data-testid={`workspace-byok-coverage-${row.featureKind}`}
key={row.featureKind}
>
<div
className={`${styles.capabilityIcon} ${
covered ? styles.capabilityIconActive : ''
}`}
>
{coverageIcon(row.icon)}
</div>
<div className={styles.rowMain}>
<div className={styles.rowTitle}>{byokT(t, row.titleKey)}</div>
<div className={styles.rowDescription}>
{warningDescription(t, warning) ?? byokT(t, row.fallbackKey)}
</div>
</div>
</div>
);
})}
</div>
</div>
);
};

View File

@@ -0,0 +1,19 @@
function errorMetadata(error: unknown) {
if (!error || typeof error !== 'object') {
return { kind: typeof error };
}
const record = error as Record<string, unknown>;
return {
name: typeof record.name === 'string' ? record.name : undefined,
code: typeof record.code === 'string' ? record.code : undefined,
status:
typeof record.status === 'number' || typeof record.status === 'string'
? record.status
: undefined,
type: typeof record.type === 'string' ? record.type : undefined,
};
}
export function logByokError(context: string, error: unknown) {
console.warn(context, errorMetadata(error));
}

View File

@@ -0,0 +1,251 @@
import { cssVar } from '@toeverything/theme';
import { cssVarV2 } from '@toeverything/theme/v2';
import { style } from '@vanilla-extract/css';
export const stack = style({
display: 'flex',
flexDirection: 'column',
gap: 24,
});
export const panel = style({
border: `1px solid ${cssVarV2('layer/insideBorder/border')}`,
borderRadius: 8,
overflow: 'hidden',
background: cssVarV2('layer/background/primary'),
});
export const panelHeader = style({
display: 'flex',
alignItems: 'center',
justifyContent: 'space-between',
gap: 12,
padding: '12px 16px',
borderBottom: `1px solid ${cssVarV2('layer/insideBorder/border')}`,
});
export const title = style({
fontSize: cssVar('fontSm'),
fontWeight: 600,
color: cssVarV2('text/primary'),
});
export const description = style({
fontSize: cssVar('fontXs'),
lineHeight: '20px',
color: cssVarV2('text/secondary'),
});
export const empty = style({
display: 'flex',
flexDirection: 'column',
alignItems: 'center',
gap: 4,
padding: '28px 20px',
textAlign: 'center',
});
export const rows = style({
display: 'flex',
flexDirection: 'column',
});
export const row = style({
display: 'grid',
gridTemplateColumns: '24px 1fr auto',
alignItems: 'center',
gap: 12,
padding: '12px 16px',
borderBottom: `1px solid ${cssVarV2('layer/insideBorder/border')}`,
selectors: {
'&:last-child': {
borderBottom: 0,
},
},
});
export const capabilityRow = style({
gridTemplateColumns: '32px 1fr',
});
export const capabilityRowInactive = style({
opacity: 0.48,
background: cssVarV2('layer/background/secondary'),
});
export const capabilityIcon = style({
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
width: 24,
height: 24,
borderRadius: 6,
color: cssVarV2('text/secondary'),
background: cssVarV2('layer/background/secondary'),
});
export const capabilityIconActive = style({
color: cssVarV2('button/primary'),
background: '#f0f7ff',
});
export const capabilityIconSvg = style({
width: 16,
height: 16,
});
export const rowDisabled = style({
opacity: 0.55,
background: cssVarV2('layer/background/secondary'),
});
export const dragHandle = style({
color: cssVarV2('text/secondary'),
cursor: 'grab',
textAlign: 'center',
});
export const rowMain = style({
minWidth: 0,
display: 'flex',
flexDirection: 'column',
gap: 4,
});
export const rowTitle = style({
display: 'flex',
alignItems: 'center',
gap: 8,
minWidth: 0,
fontSize: cssVar('fontSm'),
fontWeight: 600,
color: cssVarV2('text/primary'),
});
export const rowDescription = style({
overflow: 'hidden',
textOverflow: 'ellipsis',
whiteSpace: 'nowrap',
fontSize: cssVar('fontXs'),
color: cssVarV2('text/secondary'),
});
export const tags = style({
display: 'flex',
flexWrap: 'wrap',
gap: 6,
});
export const tag = style({
borderRadius: 999,
padding: '2px 8px',
fontSize: 11,
lineHeight: '16px',
color: cssVarV2('text/secondary'),
background: cssVarV2('layer/background/secondary'),
});
export const dangerTag = style({
color: '#b42318',
background: '#fff5f5',
});
export const rowActions = style({
display: 'flex',
alignItems: 'center',
gap: 8,
});
export const notice = style({
borderRadius: 8,
padding: 12,
border: `1px solid ${cssVarV2('layer/insideBorder/border')}`,
background: cssVarV2('layer/background/secondary'),
});
export const locked = style({
display: 'flex',
flexDirection: 'column',
gap: 16,
padding: 24,
borderRadius: 8,
background: cssVarV2('layer/background/secondary'),
border: `1px solid ${cssVarV2('layer/insideBorder/border')}`,
});
export const form = style({
display: 'flex',
flexDirection: 'column',
gap: 12,
});
export const field = style({
display: 'flex',
flexDirection: 'column',
gap: 4,
height: '3em',
});
export const label = style({
fontSize: cssVar('fontXs'),
color: cssVarV2('text/secondary'),
});
export const input = style({
height: 32,
width: '100%',
boxSizing: 'border-box',
borderRadius: 8,
border: `1px solid ${cssVarV2('layer/insideBorder/border')}`,
padding: '0 10px',
fontSize: cssVar('fontSm'),
lineHeight: '22px',
background: cssVarV2('layer/background/primary'),
color: cssVarV2('text/primary'),
outline: 'none',
selectors: {
'&::placeholder': {
color: cssVarV2('text/placeholder'),
},
'&:focus': {
borderColor: cssVarV2('button/primary'),
boxShadow: '0px 0px 0px 2px rgba(30, 150, 235, 0.30)',
},
},
});
export const modalActions = style({
display: 'flex',
alignItems: 'center',
justifyContent: 'flex-end',
gap: 8,
marginTop: 8,
});
export const testStatus = style({
marginRight: 'auto',
fontSize: cssVar('fontXs'),
});
export const error = style({
color: '#b42318',
});
export const success = style({
color: '#168a58',
});
export const chart = style({
display: 'grid',
gridTemplateColumns: 'repeat(30, minmax(4px, 1fr))',
alignItems: 'end',
gap: 4,
height: 140,
padding: '16px 16px 24px',
});
export const bar = style({
minHeight: 2,
borderRadius: '4px 4px 0 0',
background: '#5b8def',
});

View File

@@ -0,0 +1,700 @@
/**
* @vitest-environment happy-dom
*/
import {
cleanup,
fireEvent,
render,
screen,
waitFor,
} from '@testing-library/react';
import type * as Infra from '@toeverything/infra';
import type { ButtonHTMLAttributes, ReactNode } from 'react';
import { afterEach, beforeEach, describe, expect, test, vi } from 'vitest';
const gqlMock = vi.hoisted(() => vi.fn());
const workspaceState = vi.hoisted(() => ({
id: 'workspace-1',
}));
const electronApiState = vi.hoisted(() => ({
apis: undefined as
| {
byokStorage?: {
isSupported: () => Promise<boolean>;
listWorkspaceKeys: (workspaceId: string) => Promise<unknown[]>;
};
}
| undefined,
}));
const WorkspaceServerServiceToken = vi.hoisted(
() => class WorkspaceServerService {}
);
const WorkspaceServiceToken = vi.hoisted(() => class WorkspaceService {});
const ByokProvider = vi.hoisted(() => ({
openai: 'openai',
anthropic: 'anthropic',
gemini: 'gemini',
fal: 'fal',
}));
const ByokKeyStorage = vi.hoisted(() => ({
server: 'server',
local: 'local',
}));
const ByokKeyTestStatus = vi.hoisted(() => ({
untested: 'untested',
passed: 'passed',
failed: 'failed',
}));
const workspaceByokSettingsQuery = vi.hoisted(() =>
Symbol('workspaceByokSettingsQuery')
);
const testWorkspaceByokConfigMutation = vi.hoisted(() =>
Symbol('testWorkspaceByokConfigMutation')
);
const upsertWorkspaceByokConfigMutation = vi.hoisted(() =>
Symbol('upsertWorkspaceByokConfigMutation')
);
const clearWorkspaceByokConfigsMutation = vi.hoisted(() =>
Symbol('clearWorkspaceByokConfigsMutation')
);
const deleteWorkspaceByokConfigMutation = vi.hoisted(() =>
Symbol('deleteWorkspaceByokConfigMutation')
);
vi.mock('@affine/component', () => ({
Button: ({
children,
...props
}: ButtonHTMLAttributes<HTMLButtonElement> & { children: ReactNode }) => (
<button {...props}>{children}</button>
),
DragHandle: () => <span>drag-handle</span>,
IconButton: ({ title, onClick }: { title: string; onClick?: () => void }) => (
<button onClick={onClick}>{title}</button>
),
Modal: ({
open,
title,
children,
}: {
open: boolean;
title: string;
children: ReactNode;
}) =>
open ? (
<div role="dialog" aria-label={title}>
{children}
</div>
) : null,
notify: {
error: vi.fn(),
},
}));
vi.mock('@affine/component/setting-components', () => ({
SettingHeader: ({
title,
subtitle,
}: {
title: string;
subtitle?: string;
}) => (
<header>
<h1>{title}</h1>
{subtitle ? <p>{subtitle}</p> : null}
</header>
),
SettingWrapper: ({ children }: { children: ReactNode }) => (
<main>{children}</main>
),
}));
vi.mock('@affine/core/modules/cloud', () => ({
WorkspaceServerService: WorkspaceServerServiceToken,
}));
vi.mock('@affine/core/modules/workspace', () => ({
WorkspaceService: WorkspaceServiceToken,
}));
vi.mock('@affine/electron-api', () => ({
get apis() {
return electronApiState.apis;
},
}));
vi.mock('@affine/graphql', () => ({
ByokKeyStorage,
ByokKeyTestStatus,
ByokProvider,
clearWorkspaceByokConfigsMutation,
deleteWorkspaceByokConfigMutation,
testWorkspaceByokConfigMutation,
upsertWorkspaceByokConfigMutation,
workspaceByokSettingsQuery,
}));
vi.mock('@affine/i18n', () => {
const messages: Record<string, string> = {
'com.affine.settings.workspace.byok.action.add-key': 'Add key',
'com.affine.settings.workspace.byok.action.edit': 'Edit',
'com.affine.settings.workspace.byok.action.delete': 'Delete',
'com.affine.settings.workspace.byok.action.test-key': 'Test key',
'com.affine.settings.workspace.byok.action.save-key': 'Save key',
'com.affine.settings.workspace.byok.action.cancel': 'Cancel',
'com.affine.settings.workspace.byok.action.clear-all':
'Clear all BYOK keys',
'com.affine.settings.workspace.byok.field.api-key': 'API key',
'com.affine.settings.workspace.byok.field.storage': 'Key storage',
'com.affine.settings.workspace.byok.placeholder.key-name': 'Primary',
'com.affine.settings.workspace.byok.status.key-verified': 'Key verified',
'com.affine.settings.workspace.byok.status.disabled-after-failure':
'Disabled after failure',
'com.affine.settings.workspace.byok.storage.local': 'Local',
'com.affine.settings.workspace.byok.storage.server': 'Server',
'com.affine.settings.workspace.byok.storage.local-this-device':
'Local (this device)',
'com.affine.settings.workspace.byok.storage.local-desktop-only':
'Local (Desktop only)',
'com.affine.settings.workspace.byok.usage.tokens': '{{count}} tokens',
'com.affine.settings.workspace.byok.notify.operation-failed.message':
'Please try again.',
'com.affine.settings.workspace.byok.notify.test-failed.title':
'Key test failed',
'com.affine.settings.workspace.byok.notify.load-failed.title':
'BYOK settings not loaded',
'com.affine.settings.workspace.byok.notify.save-failed.title':
'BYOK key not saved',
'com.affine.settings.workspace.byok.notify.delete-failed.title':
'BYOK key not deleted',
'com.affine.settings.workspace.byok.notify.reorder-failed.title':
'BYOK keys not reordered',
'com.affine.settings.workspace.byok.notify.clear-failed.title':
'BYOK keys not cleared',
};
const translate = (key: string, options?: Record<string, unknown>) => {
let message = messages[key] ?? key;
for (const [name, value] of Object.entries(options ?? {})) {
message = message.replaceAll(`{{${name}}}`, String(value));
}
return message;
};
const t = new Proxy(
{
t: translate,
},
{
get(target, key: string) {
if (key in target) {
return target[key as keyof typeof target];
}
return (options?: Record<string, unknown>) => translate(key, options);
},
}
);
return {
useI18n: () => t,
};
});
vi.mock('@blocksuite/icons/rc', () => ({
ChatWithAiIcon: () => <span>chat-ai</span>,
DeleteIcon: () => <span>delete</span>,
EditIcon: () => <span>edit</span>,
ImageIcon: () => <span>image</span>,
PenIcon: () => <span>pen</span>,
TocIcon: () => <span>toc</span>,
TranscriptWithAiIcon: () => <span>transcript</span>,
}));
vi.mock('@toeverything/infra', async importOriginal => {
const actual = await importOriginal<typeof Infra>();
return {
...actual,
useService: (token: unknown) => {
if (token === WorkspaceServerServiceToken) {
return {
server: {
gql: gqlMock,
},
};
}
if (token === WorkspaceServiceToken) {
return {
workspace: workspaceState,
};
}
return {};
},
};
});
import { WorkspaceByokSetting } from '.';
import { logByokError } from './errors';
import { UsagePanel } from './usage';
function settings(overrides: Record<string, unknown> = {}) {
return {
workspaceId: 'workspace-1',
entitled: true,
serverEntitled: true,
localEntitled: false,
entitlementRequired: ['Pro', 'Team', 'Believer'],
allowedProviders: ['openai', 'anthropic', 'gemini', 'fal'],
localStorageSupported: false,
customEndpointSupported: false,
hasAiPlan: true,
keys: [],
warnings: [],
...overrides,
};
}
function byokKey(overrides: Record<string, unknown> = {}) {
return {
id: 'server-key',
provider: ByokProvider.openai,
name: 'Primary',
description: 'Workspace fallback key',
storage: ByokKeyStorage.server,
configured: true,
enabled: true,
endpoint: null,
endpointEditable: false,
sortOrder: 0,
capabilities: ['Text', 'Image input', 'Actions', 'Image generate'],
testStatus: ByokKeyTestStatus.passed,
disabledReason: null,
lastTestedAt: null,
lastTestError: null,
lastUsedAt: null,
lastErrorAt: null,
lastError: null,
...overrides,
};
}
function settingsResponse(overrides: Record<string, unknown> = {}) {
return {
workspace: {
byokSettings: settings(overrides),
byokUsage: [],
},
};
}
describe('WorkspaceByokSetting', () => {
afterEach(() => {
cleanup();
vi.unstubAllGlobals();
});
beforeEach(() => {
gqlMock.mockReset();
gqlMock.mockImplementation(async ({ query }) => {
if (query === workspaceByokSettingsQuery) {
return settingsResponse();
}
throw new Error('Unexpected GraphQL operation');
});
vi.stubGlobal('BUILD_CONFIG', { isElectron: false });
electronApiState.apis = undefined;
});
test('renders locked state without key management controls', async () => {
gqlMock.mockImplementation(async ({ query }) => {
if (query === workspaceByokSettingsQuery) {
return settingsResponse({
entitled: false,
serverEntitled: false,
localEntitled: false,
});
}
throw new Error('Unexpected GraphQL operation');
});
render(<WorkspaceByokSetting />);
await screen.findByTestId('workspace-byok-locked');
expect(screen.queryByText('Add key')).toBeNull();
expect(screen.queryByTestId('workspace-byok-empty')).toBeNull();
});
test('renders empty state and keeps save disabled until key test passes', async () => {
gqlMock.mockImplementation(async ({ query }) => {
if (query === workspaceByokSettingsQuery) {
return settingsResponse();
}
if (query === testWorkspaceByokConfigMutation) {
return {
testWorkspaceByokConfig: {
ok: true,
status: 'passed',
message: null,
},
};
}
if (query === upsertWorkspaceByokConfigMutation) {
return { upsertWorkspaceByokConfig: { id: 'server-key' } };
}
throw new Error('Unexpected GraphQL operation');
});
render(<WorkspaceByokSetting />);
await screen.findByTestId('workspace-byok-empty');
fireEvent.click(screen.getAllByText('Add key')[0]);
expect(screen.getByText<HTMLButtonElement>('Save key').disabled).toBe(true);
fireEvent.change(screen.getByPlaceholderText('Primary'), {
target: { value: 'Primary' },
});
fireEvent.change(screen.getByLabelText('API key'), {
target: { value: 'sk-test' },
});
fireEvent.click(screen.getByText('Test key'));
await screen.findByText('Key verified');
expect(screen.getByText<HTMLButtonElement>('Save key').disabled).toBe(
false
);
fireEvent.click(screen.getByText('Save key'));
await waitFor(() => {
expect(gqlMock).toHaveBeenCalledWith(
expect.objectContaining({
query: upsertWorkspaceByokConfigMutation,
})
);
});
});
test('keeps local storage disabled on web even for local-entitled users', async () => {
gqlMock.mockImplementation(async ({ query }) => {
if (query === workspaceByokSettingsQuery) {
return settingsResponse({
localEntitled: true,
localStorageSupported: true,
});
}
throw new Error('Unexpected GraphQL operation');
});
render(<WorkspaceByokSetting />);
await screen.findByTestId('workspace-byok-empty');
fireEvent.click(screen.getAllByText('Add key')[0]);
const storageSelect =
screen.getByLabelText<HTMLSelectElement>('Key storage');
const localOption = Array.from(storageSelect.options).find(
option => option.value === ByokKeyStorage.local
);
expect(localOption?.disabled).toBe(true);
});
test('reorders server keys within their storage bucket', async () => {
gqlMock.mockImplementation(async ({ query }) => {
if (query === workspaceByokSettingsQuery) {
return settingsResponse({
keys: [
byokKey({ id: 'server-1', name: 'First', sortOrder: 0 }),
byokKey({ id: 'server-2', name: 'Second', sortOrder: 1 }),
],
});
}
return {};
});
render(<WorkspaceByokSetting />);
const firstRow = (await screen.findByText('OpenAI / First')).closest(
'[draggable="true"]'
);
const secondRow = screen
.getByText('OpenAI / Second')
.closest('[draggable="true"]');
expect(firstRow).not.toBeNull();
expect(secondRow).not.toBeNull();
fireEvent.dragStart(firstRow as Element);
fireEvent.dragOver(secondRow as Element);
fireEvent.drop(secondRow as Element);
await waitFor(() => {
expect(gqlMock).toHaveBeenCalledWith(
expect.objectContaining({
variables: expect.objectContaining({
input: expect.objectContaining({
workspaceId: 'workspace-1',
storage: ByokKeyStorage.server,
ids: ['server-2', 'server-1'],
}),
}),
})
);
});
});
test('marks coverage rows by configured provider support', async () => {
let keys = [
byokKey({ provider: ByokProvider.openai }),
byokKey({
id: 'disabled-gemini',
provider: ByokProvider.gemini,
enabled: false,
capabilities: [
'Text',
'Image input',
'Actions',
'Image generate',
'Transcript',
'Indexing',
],
}),
byokKey({
id: 'local-gemini',
provider: ByokProvider.gemini,
storage: ByokKeyStorage.local,
capabilities: ['Text', 'Image input', 'Actions', 'Image generate'],
}),
];
gqlMock.mockImplementation(async ({ query }) => {
if (query === workspaceByokSettingsQuery) {
return settingsResponse({
keys,
});
}
throw new Error('Unexpected GraphQL operation');
});
render(<WorkspaceByokSetting />);
expect(
(await screen.findByTestId('workspace-byok-coverage-chat')).dataset
.covered
).toBe('true');
expect(
screen.getByTestId('workspace-byok-coverage-action').dataset.covered
).toBe('true');
expect(
screen.getByTestId('workspace-byok-coverage-image').dataset.covered
).toBe('true');
expect(
screen.getByTestId('workspace-byok-coverage-transcript').dataset.covered
).toBe('false');
expect(
screen.getByTestId('workspace-byok-coverage-workspace_indexing').dataset
.covered
).toBe('false');
expect(screen.getAllByTestId(/^workspace-byok-coverage-/)).toHaveLength(5);
cleanup();
keys = [
byokKey({
provider: ByokProvider.gemini,
capabilities: [
'Text',
'Image input',
'Actions',
'Image generate',
'Transcript',
'Indexing',
],
}),
];
render(<WorkspaceByokSetting />);
expect(
(await screen.findByTestId('workspace-byok-coverage-transcript')).dataset
.covered
).toBe('true');
expect(
screen.getByTestId('workspace-byok-coverage-workspace_indexing').dataset
.covered
).toBe('true');
});
test('restores a failed server row after key test passes', async () => {
gqlMock.mockImplementation(async ({ query }) => {
if (query === workspaceByokSettingsQuery) {
return settingsResponse({
keys: [
byokKey({
enabled: false,
testStatus: ByokKeyTestStatus.failed,
disabledReason: 'recent_failure',
lastErrorAt: '2026-05-01T00:00:00.000Z',
lastError: 'Provider rejected the API key.',
}),
],
});
}
if (query === testWorkspaceByokConfigMutation) {
return {
testWorkspaceByokConfig: {
ok: true,
status: 'passed',
message: null,
},
};
}
if (query === upsertWorkspaceByokConfigMutation) {
return {
upsertWorkspaceByokConfig: {
id: 'server-key',
},
};
}
throw new Error('Unexpected GraphQL operation');
});
render(<WorkspaceByokSetting />);
await screen.findByText('Disabled after failure');
fireEvent.click(screen.getByText('Edit'));
fireEvent.change(screen.getByLabelText('API key'), {
target: { value: 'sk-test' },
});
fireEvent.click(screen.getByText('Test key'));
await screen.findByText('Key verified');
fireEvent.click(screen.getByText('Save key'));
await waitFor(() => {
expect(gqlMock).toHaveBeenCalledWith(
expect.objectContaining({
query: upsertWorkspaceByokConfigMutation,
variables: expect.objectContaining({
input: expect.objectContaining({
id: 'server-key',
enabled: true,
}),
}),
})
);
});
});
test('tests a saved server key without resending plaintext', async () => {
gqlMock.mockImplementation(async ({ query }) => {
if (query === workspaceByokSettingsQuery) {
return settingsResponse({
keys: [byokKey()],
});
}
if (query === testWorkspaceByokConfigMutation) {
return {
testWorkspaceByokConfig: {
ok: true,
status: 'passed',
message: null,
},
};
}
if (query === upsertWorkspaceByokConfigMutation) {
return {
upsertWorkspaceByokConfig: {
id: 'server-key',
},
};
}
throw new Error('Unexpected GraphQL operation');
});
render(<WorkspaceByokSetting />);
await screen.findByText('OpenAI / Primary');
fireEvent.click(screen.getByText('Edit'));
expect(screen.getByText<HTMLButtonElement>('Test key').disabled).toBe(
false
);
fireEvent.click(screen.getByText('Test key'));
await waitFor(() => {
expect(gqlMock).toHaveBeenCalledWith(
expect.objectContaining({
query: testWorkspaceByokConfigMutation,
variables: expect.objectContaining({
input: expect.objectContaining({
apiKey: null,
configId: 'server-key',
}),
}),
})
);
});
await screen.findByText('Key verified');
fireEvent.click(screen.getByText('Save key'));
await waitFor(() => {
expect(gqlMock).toHaveBeenCalledWith(
expect.objectContaining({
query: upsertWorkspaceByokConfigMutation,
variables: expect.objectContaining({
input: expect.objectContaining({
apiKey: null,
id: 'server-key',
}),
}),
})
);
});
});
});
describe('UsagePanel', () => {
afterEach(() => {
cleanup();
});
test('aggregates usage rows by date before rendering bars', () => {
const today = new Date().toISOString();
render(
<UsagePanel
keys={[]}
usage={[
{ date: today, featureKind: 'chat', totalTokens: 3 },
{ date: today, featureKind: 'transcript', totalTokens: 5 },
]}
onClearAll={() => {}}
/>
);
expect(screen.getByTitle('8 tokens')).not.toBeNull();
});
});
describe('logByokError', () => {
test('logs safe metadata without raw error message', () => {
const warn = vi.spyOn(console, 'warn').mockImplementation(() => {});
const error = Object.assign(
new Error('authorization: Bearer token=a+b%2F=='),
{
code: 'BAD_REQUEST',
status: 400,
type: 'bad_request',
}
);
try {
logByokError('byok', error);
expect(warn).toHaveBeenCalledWith('byok', {
name: 'Error',
code: 'BAD_REQUEST',
status: 400,
type: 'bad_request',
});
expect(JSON.stringify(warn.mock.calls)).not.toContain('token=a+b%2F==');
} finally {
warn.mockRestore();
}
});
});

View File

@@ -0,0 +1,363 @@
import { Button, notify } from '@affine/component';
import {
SettingHeader,
SettingWrapper,
} from '@affine/component/setting-components';
import { WorkspaceServerService } from '@affine/core/modules/cloud';
import { WorkspaceService } from '@affine/core/modules/workspace';
import {
ByokKeyStorage,
clearWorkspaceByokConfigsMutation as clearByokMutation,
deleteWorkspaceByokConfigMutation as deleteByokMutation,
type GraphQLQuery,
workspaceByokSettingsQuery as byokSettingsQuery,
} from '@affine/graphql';
import { useI18n } from '@affine/i18n';
import { useService } from '@toeverything/infra';
import { useCallback, useEffect, useMemo, useState } from 'react';
import { AddKeyModal } from './add-key-modal';
import { CoveragePanel } from './coverage';
import { logByokError } from './errors';
import * as styles from './index.css';
import { KeyList } from './key-list';
import {
clearLocalKeys,
deleteLocalKey,
localByokStorageSupported,
readLocalKeys,
reorderLocalKeys,
} from './local-storage';
import { byokT } from './metadata';
import type {
ByokKey,
ByokSettings,
ByokStorage,
ByokUsagePoint,
GqlFn,
} from './types';
import { UsagePanel } from './usage';
const reorderByokMutation = {
id: 'reorderWorkspaceByokConfigsMutation',
op: 'reorderWorkspaceByokConfigs',
query: `mutation reorderWorkspaceByokConfigs($input: ReorderWorkspaceByokConfigsInput!) {
reorderWorkspaceByokConfigs(input: $input) {
id
sortOrder
}
}`,
} satisfies GraphQLQuery;
export const WorkspaceByokSetting = () => {
const t = useI18n();
const workspace = useService(WorkspaceService).workspace;
const workspaceServer = useService(WorkspaceServerService);
const [settings, setSettings] = useState<ByokSettings | null>(null);
const [usage, setUsage] = useState<ByokUsagePoint[]>([]);
const [localKeys, setLocalKeys] = useState<ByokKey[]>([]);
const [modalOpen, setModalOpen] = useState(false);
const [editingKey, setEditingKey] = useState<ByokKey | null>(null);
const [draggingKey, setDraggingKey] = useState<{
id: string;
storage: ByokStorage;
} | null>(null);
const load = useCallback(async () => {
if (!workspaceServer.server) {
return;
}
const to = new Date();
const from = new Date(to.getTime() - 30 * 24 * 60 * 60 * 1000);
const gql = workspaceServer.server.gql as GqlFn;
const data = await gql({
query: byokSettingsQuery,
variables: {
id: workspace.id,
from: from.toISOString(),
to: to.toISOString(),
},
});
const [localStorageSupported, nextLocalKeys] = await Promise.all([
localByokStorageSupported(),
readLocalKeys(workspace.id),
]);
setSettings({
...data.workspace.byokSettings,
localStorageSupported:
data.workspace.byokSettings.localEntitled && localStorageSupported,
});
setUsage(data.workspace.byokUsage);
setLocalKeys(nextLocalKeys);
}, [workspace.id, workspaceServer.server]);
useEffect(() => {
load().catch(error => {
logByokError('Failed to load BYOK settings', error);
notify.error({
title: byokT(t, 'notify.load-failed.title'),
message: byokT(t, 'notify.operation-failed.message'),
});
});
}, [load, t]);
const keys = useMemo(() => {
return [...localKeys, ...(settings?.keys ?? [])].toSorted((a, b) => {
if (a.storage !== b.storage) {
return a.storage === ByokKeyStorage.local ? -1 : 1;
}
return a.sortOrder - b.sortOrder;
});
}, [localKeys, settings?.keys]);
const canAddServerKey = settings?.serverEntitled ?? false;
const canAddLocalKey =
(settings?.localEntitled ?? false) &&
(settings?.localStorageSupported ?? false);
const canManageKeys = canAddServerKey || canAddLocalKey;
const clearAll = useCallback(async () => {
if (!settings) {
return;
}
if (!workspaceServer.server && settings.serverEntitled) {
return;
}
if (settings.serverEntitled && workspaceServer.server) {
const gql = workspaceServer.server.gql as GqlFn;
await gql({
query: clearByokMutation,
variables: { workspaceId: workspace.id },
});
}
if (settings.localStorageSupported) {
await clearLocalKeys(workspace.id);
}
setLocalKeys([]);
await load();
}, [load, settings, workspace.id, workspaceServer.server]);
const deleteKey = useCallback(
async (key: ByokKey) => {
if (key.storage === ByokKeyStorage.local) {
await deleteLocalKey(workspace.id, key.id);
setLocalKeys(await readLocalKeys(workspace.id));
return;
}
const gql = workspaceServer.server?.gql as
| ((input: {
query: GraphQLQuery;
variables?: Record<string, unknown>;
}) => Promise<unknown>)
| undefined;
await gql?.({
query: deleteByokMutation,
variables: { workspaceId: workspace.id, id: key.id },
});
await load();
},
[load, workspace.id, workspaceServer.server]
);
const reorderKey = useCallback(
async (targetKey: ByokKey) => {
if (!draggingKey || draggingKey.id === targetKey.id) {
return;
}
if (draggingKey.storage !== targetKey.storage) {
notify.error({
title: byokT(t, 'notify.cross-storage-reorder.title'),
message: byokT(t, 'notify.cross-storage-reorder.message'),
});
return;
}
const bucket = keys.filter(key => key.storage === targetKey.storage);
const fromIndex = bucket.findIndex(key => key.id === draggingKey.id);
const toIndex = bucket.findIndex(key => key.id === targetKey.id);
if (fromIndex === -1 || toIndex === -1) {
return;
}
const nextBucket = [...bucket];
const [moved] = nextBucket.splice(fromIndex, 1);
nextBucket.splice(toIndex, 0, moved);
const nextBucketIds = nextBucket.map(key => key.id);
if (targetKey.storage === ByokKeyStorage.local) {
setLocalKeys(await reorderLocalKeys(workspace.id, nextBucketIds));
return;
}
const gql = workspaceServer.server?.gql as
| ((input: {
query: GraphQLQuery;
variables?: Record<string, unknown>;
}) => Promise<unknown>)
| undefined;
await gql?.({
query: reorderByokMutation,
variables: {
input: {
workspaceId: workspace.id,
storage: ByokKeyStorage.server,
ids: nextBucketIds,
},
},
});
await load();
},
[draggingKey, keys, load, t, workspace.id, workspaceServer.server]
);
if (!settings) {
return (
<SettingHeader
title={byokT(t, 'title-beta')}
subtitle={byokT(t, 'loading')}
/>
);
}
if (!settings.entitled) {
return (
<>
<SettingHeader
title={byokT(t, 'title-beta')}
subtitle={byokT(t, 'subtitle')}
/>
<SettingWrapper>
<div className={styles.locked} data-testid="workspace-byok-locked">
<div>
<div className={styles.title}>{byokT(t, 'locked.title')}</div>
<div className={styles.description}>
{byokT(t, 'locked.description')}
</div>
</div>
<div className={styles.tags}>
{settings.entitlementRequired.map(plan => (
<span className={styles.tag} key={plan}>
{plan}
</span>
))}
</div>
</div>
</SettingWrapper>
</>
);
}
return (
<>
<SettingHeader
title={byokT(t, 'title-beta')}
subtitle={byokT(t, 'header')}
/>
<SettingWrapper>
<div className={styles.stack}>
{settings.hasAiPlan ? (
<div className={styles.notice}>
<div className={styles.title}>{byokT(t, 'notice.title')}</div>
<div className={styles.description}>
{byokT(t, 'notice.description')}
</div>
</div>
) : null}
<div className={styles.panel} data-testid="workspace-byok-keys">
<div className={styles.panelHeader}>
<div>
<div className={styles.title}>{byokT(t, 'keys.title')}</div>
<div className={styles.description}>
{byokT(t, 'keys.description')}
</div>
</div>
<Button
variant="primary"
disabled={!canManageKeys}
onClick={() => {
setEditingKey(null);
setModalOpen(true);
}}
>
{byokT(t, 'action.add-key')}
</Button>
</div>
{keys.length ? (
<KeyList
keys={keys}
onEdit={key => {
setEditingKey(key);
setModalOpen(true);
}}
onDelete={key => {
deleteKey(key).catch(error => {
logByokError('Failed to delete BYOK key', error);
notify.error({
title: byokT(t, 'notify.delete-failed.title'),
message: byokT(t, 'notify.operation-failed.message'),
});
});
}}
onDragStart={key => {
setDraggingKey({ id: key.id, storage: key.storage });
}}
onDragEnd={() => setDraggingKey(null)}
onDrop={key => {
reorderKey(key).catch(error => {
logByokError('Failed to reorder BYOK keys', error);
notify.error({
title: byokT(t, 'notify.reorder-failed.title'),
message: byokT(t, 'notify.operation-failed.message'),
});
});
}}
/>
) : (
<div className={styles.empty} data-testid="workspace-byok-empty">
<div className={styles.title}>{byokT(t, 'empty.title')}</div>
<div className={styles.description}>
{byokT(t, 'empty.description')}
</div>
</div>
)}
</div>
<CoveragePanel keys={keys} settings={settings} />
<UsagePanel
keys={keys}
usage={usage}
onClearAll={() => {
clearAll().catch(error => {
logByokError('Failed to clear BYOK keys', error);
notify.error({
title: byokT(t, 'notify.clear-failed.title'),
message: byokT(t, 'notify.operation-failed.message'),
});
});
}}
/>
</div>
</SettingWrapper>
<AddKeyModal
workspaceId={workspace.id}
settings={settings}
editingKey={editingKey}
open={modalOpen}
onOpenChange={open => {
setModalOpen(open);
if (!open) {
setEditingKey(null);
}
}}
onSaved={load}
localKeys={localKeys}
setLocalKeys={setLocalKeys}
localStorageSupported={settings.localStorageSupported}
canAddServerKey={canAddServerKey}
canAddLocalKey={canAddLocalKey}
gql={workspaceServer.server?.gql as GqlFn | undefined}
/>
</>
);
};

View File

@@ -0,0 +1,92 @@
import { DragHandle, IconButton } from '@affine/component';
import { useI18n } from '@affine/i18n';
import { DeleteIcon, EditIcon } from '@blocksuite/icons/rc';
import type { DragEvent } from 'react';
import * as styles from './index.css';
import {
byokT,
capabilityLabel,
providerLabels,
rowDescription,
storageLabel,
} from './metadata';
import type { ByokKey } from './types';
export const KeyList = ({
keys,
onEdit,
onDelete,
onDragStart,
onDragEnd,
onDrop,
}: {
keys: ByokKey[];
onEdit: (key: ByokKey) => void;
onDelete: (key: ByokKey) => void;
onDragStart: (key: ByokKey) => void;
onDragEnd: () => void;
onDrop: (key: ByokKey) => void;
}) => {
const t = useI18n();
return (
<div className={styles.rows}>
{keys.map(key => (
<div
className={`${styles.row} ${key.enabled ? '' : styles.rowDisabled}`}
draggable
key={`${key.storage}:${key.id}`}
onDragStart={() => onDragStart(key)}
onDragEnd={onDragEnd}
onDragOver={(event: DragEvent<HTMLDivElement>) => {
event.preventDefault();
}}
onDrop={event => {
event.preventDefault();
onDrop(key);
}}
>
<div className={styles.dragHandle} title={byokT(t, 'action.reorder')}>
<DragHandle />
</div>
<div className={styles.rowMain}>
<div className={styles.rowTitle}>
{providerLabels[key.provider]} / {key.name}
<span className={styles.tag}>{storageLabel(t, key.storage)}</span>
{!key.enabled ? (
<span className={`${styles.tag} ${styles.dangerTag}`}>
{byokT(t, 'status.disabled-after-failure')}
</span>
) : null}
</div>
<div className={styles.rowDescription}>
{rowDescription(t, key)}
</div>
<div className={styles.tags}>
{key.capabilities.map(capability => (
<span className={styles.tag} key={capability}>
{capabilityLabel(t, capability)}
</span>
))}
</div>
</div>
<div className={styles.rowActions}>
<IconButton
size="20"
title={byokT(t, 'action.edit')}
icon={<EditIcon />}
onClick={() => onEdit(key)}
/>
<IconButton
size="20"
title={byokT(t, 'action.delete')}
icon={<DeleteIcon />}
onClick={() => onDelete(key)}
/>
</div>
</div>
))}
</div>
);
};

View File

@@ -0,0 +1,108 @@
import { apis } from '@affine/electron-api';
import { ByokKeyStorage, ByokKeyTestStatus } from '@affine/graphql';
import { capabilitiesFor } from './metadata';
import type { ByokKey, LocalByokKeyInput, LocalByokPublicKey } from './types';
function byokStorageApi() {
return BUILD_CONFIG.isElectron ? apis?.byokStorage : undefined;
}
export async function localByokStorageSupported() {
const storage = byokStorageApi();
if (!storage) {
return false;
}
try {
return await storage.isSupported();
} catch {
return false;
}
}
function toLocalByokKey(key: LocalByokPublicKey): ByokKey {
return {
id: key.id,
provider: key.provider,
name: key.name,
description: key.description ?? null,
storage: ByokKeyStorage.local,
configured: key.configured ?? true,
enabled: key.enabled ?? true,
endpoint: key.endpoint ?? null,
endpointEditable: key.endpointEditable ?? false,
sortOrder: key.sortOrder ?? 0,
capabilities: capabilitiesFor(key.provider, ByokKeyStorage.local),
testStatus: key.testStatus ?? ByokKeyTestStatus.passed,
};
}
export async function readLocalKeys(workspaceId: string): Promise<ByokKey[]> {
const storage = byokStorageApi();
if (!(await localByokStorageSupported()) || !storage) {
return [];
}
try {
const keys = (await storage.listWorkspaceKeys(
workspaceId
)) as LocalByokPublicKey[];
return keys.map(toLocalByokKey);
} catch {
return [];
}
}
export async function upsertLocalKey(
workspaceId: string,
key: LocalByokKeyInput
) {
const storage = byokStorageApi();
if (!(await localByokStorageSupported()) || !storage) {
return null;
}
try {
return await storage.upsertWorkspaceKey(workspaceId, key);
} catch {
return null;
}
}
export async function deleteLocalKey(workspaceId: string, keyId: string) {
const storage = byokStorageApi();
if (!(await localByokStorageSupported()) || !storage) {
return false;
}
try {
return await storage.deleteWorkspaceKey(workspaceId, keyId);
} catch {
return false;
}
}
export async function reorderLocalKeys(workspaceId: string, ids: string[]) {
const storage = byokStorageApi();
if (!(await localByokStorageSupported()) || !storage) {
return [];
}
try {
const keys = (await storage.reorderWorkspaceKeys(
workspaceId,
ids
)) as LocalByokPublicKey[];
return keys.map(toLocalByokKey);
} catch {
return [];
}
}
export async function clearLocalKeys(workspaceId: string) {
const storage = byokStorageApi();
if (!(await localByokStorageSupported()) || !storage) {
return false;
}
try {
return await storage.clearWorkspaceKeys(workspaceId);
} catch {
return false;
}
}

Some files were not shown because too many files have changed in this diff Show More