diff --git a/.docker/selfhost/schema.json b/.docker/selfhost/schema.json index a5f1b19628..82065c8295 100644 --- a/.docker/selfhost/schema.json +++ b/.docker/selfhost/schema.json @@ -643,6 +643,41 @@ "apiKey": "" } }, + "providers.geminiVertex": { + "type": "object", + "description": "The config for the google vertex provider.\n@default {}", + "properties": { + "location": { + "type": "string", + "description": "The location of the google vertex provider." + }, + "project": { + "type": "string", + "description": "The project name of the google vertex provider." + }, + "googleAuthOptions": { + "type": "object", + "description": "The google auth options for the google vertex provider.", + "properties": { + "credentials": { + "type": "object", + "description": "The credentials for the google vertex provider.", + "properties": { + "client_email": { + "type": "string", + "description": "The client email for the google vertex provider." + }, + "private_key": { + "type": "string", + "description": "The private key for the google vertex provider." + } + } + } + } + } + }, + "default": {} + }, "providers.perplexity": { "type": "object", "description": "The config for the perplexity provider.\n@default {\"apiKey\":\"\"}", @@ -657,6 +692,41 @@ "apiKey": "" } }, + "providers.anthropicVertex": { + "type": "object", + "description": "The config for the google vertex provider.\n@default {}", + "properties": { + "location": { + "type": "string", + "description": "The location of the google vertex provider." + }, + "project": { + "type": "string", + "description": "The project name of the google vertex provider." + }, + "googleAuthOptions": { + "type": "object", + "description": "The google auth options for the google vertex provider.", + "properties": { + "credentials": { + "type": "object", + "description": "The credentials for the google vertex provider.", + "properties": { + "client_email": { + "type": "string", + "description": "The client email for the google vertex provider." + }, + "private_key": { + "type": "string", + "description": "The private key for the google vertex provider." + } + } + } + } + } + }, + "default": {} + }, "unsplash": { "type": "object", "description": "The config for the unsplash key.\n@default {\"key\":\"\"}", diff --git a/.github/actions/server-test-env/action.yml b/.github/actions/server-test-env/action.yml index b54ed8ba1f..789465e65a 100644 --- a/.github/actions/server-test-env/action.yml +++ b/.github/actions/server-test-env/action.yml @@ -29,11 +29,7 @@ runs: - name: Import config shell: bash + env: + DEFAULT_CONFIG: '{}' run: | - printf '{"copilot":{"enabled":true,"providers.fal":{"apiKey":"%s"},"providers.gemini":{"apiKey":"%s"},"providers.openai":{"apiKey":"%s"},"providers.perplexity":{"apiKey":"%s"},"providers.anthropic":{"apiKey":"%s"},"exa":{"key":"%s"}}}' \ - "$COPILOT_FAL_API_KEY" \ - "$COPILOT_GOOGLE_API_KEY" \ - "$COPILOT_OPENAI_API_KEY" \ - "$COPILOT_PERPLEXITY_API_KEY" \ - "$COPILOT_ANTHROPIC_API_KEY" \ - "$COPILOT_EXA_API_KEY" > ./packages/backend/server/config.json + printf '%s\n' "${SERVER_CONFIG:-$DEFAULT_CONFIG}" > ./packages/backend/server/config.json diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 2ce368afea..37d2744d1c 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -1001,12 +1001,7 @@ jobs: - name: Prepare Server Test Environment if: ${{ steps.check-blocksuite-update.outputs.skip != 'true' || steps.apifilter.outputs.changed == 'true' }} env: - COPILOT_OPENAI_API_KEY: ${{ secrets.COPILOT_OPENAI_API_KEY }} - COPILOT_GOOGLE_API_KEY: ${{ secrets.COPILOT_GOOGLE_API_KEY }} - COPILOT_FAL_API_KEY: ${{ secrets.COPILOT_FAL_API_KEY }} - COPILOT_PERPLEXITY_API_KEY: ${{ secrets.COPILOT_PERPLEXITY_API_KEY }} - COPILOT_ANTHROPIC_API_KEY: ${{ secrets.COPILOT_ANTHROPIC_API_KEY }} - COPILOT_EXA_API_KEY: ${{ secrets.COPILOT_EXA_API_KEY }} + SERVER_CONFIG: ${{ secrets.TEST_SERVER_CONFIG }} uses: ./.github/actions/server-test-env - name: Run server tests @@ -1105,12 +1100,7 @@ jobs: - name: Prepare Server Test Environment if: ${{ steps.check-blocksuite-update.outputs.skip != 'true' || steps.e2efilter.outputs.changed == 'true' }} env: - COPILOT_OPENAI_API_KEY: ${{ secrets.COPILOT_OPENAI_API_KEY }} - COPILOT_GOOGLE_API_KEY: ${{ secrets.COPILOT_GOOGLE_API_KEY }} - COPILOT_FAL_API_KEY: ${{ secrets.COPILOT_FAL_API_KEY }} - COPILOT_PERPLEXITY_API_KEY: ${{ secrets.COPILOT_PERPLEXITY_API_KEY }} - COPILOT_ANTHROPIC_API_KEY: ${{ secrets.COPILOT_ANTHROPIC_API_KEY }} - COPILOT_EXA_API_KEY: ${{ secrets.COPILOT_EXA_API_KEY }} + SERVER_CONFIG: ${{ secrets.TEST_SERVER_CONFIG }} uses: ./.github/actions/server-test-env - name: Run Copilot E2E Test ${{ matrix.shardIndex }}/${{ matrix.shardTotal }} diff --git a/.github/workflows/copilot-test.yml b/.github/workflows/copilot-test.yml index 164470e0c4..84a152aeab 100644 --- a/.github/workflows/copilot-test.yml +++ b/.github/workflows/copilot-test.yml @@ -81,12 +81,7 @@ jobs: - name: Prepare Server Test Environment env: - COPILOT_OPENAI_API_KEY: ${{ secrets.COPILOT_OPENAI_API_KEY }} - COPILOT_FAL_API_KEY: ${{ secrets.COPILOT_FAL_API_KEY }} - COPILOT_GOOGLE_API_KEY: ${{ secrets.COPILOT_GOOGLE_API_KEY }} - COPILOT_PERPLEXITY_API_KEY: ${{ secrets.COPILOT_PERPLEXITY_API_KEY }} - COPILOT_ANTHROPIC_API_KEY: ${{ secrets.COPILOT_ANTHROPIC_API_KEY }} - COPILOT_EXA_API_KEY: ${{ secrets.COPILOT_EXA_API_KEY }} + SERVER_CONFIG: ${{ secrets.TEST_SERVER_CONFIG }} uses: ./.github/actions/server-test-env - name: Run server tests @@ -156,12 +151,7 @@ jobs: - name: Prepare Server Test Environment env: - COPILOT_OPENAI_API_KEY: ${{ secrets.COPILOT_OPENAI_API_KEY }} - COPILOT_FAL_API_KEY: ${{ secrets.COPILOT_FAL_API_KEY }} - COPILOT_GOOGLE_API_KEY: ${{ secrets.COPILOT_GOOGLE_API_KEY }} - COPILOT_PERPLEXITY_API_KEY: ${{ secrets.COPILOT_PERPLEXITY_API_KEY }} - COPILOT_ANTHROPIC_API_KEY: ${{ secrets.COPILOT_ANTHROPIC_API_KEY }} - COPILOT_EXA_API_KEY: ${{ secrets.COPILOT_EXA_API_KEY }} + SERVER_CONFIG: ${{ secrets.TEST_SERVER_CONFIG }} uses: ./.github/actions/server-test-env - name: Run Copilot E2E Test ${{ matrix.shardIndex }}/${{ matrix.shardTotal }} diff --git a/packages/backend/server/package.json b/packages/backend/server/package.json index c784d24cb5..58e5095864 100644 --- a/packages/backend/server/package.json +++ b/packages/backend/server/package.json @@ -30,6 +30,7 @@ "@affine/server-native": "workspace:*", "@ai-sdk/anthropic": "^1.2.10", "@ai-sdk/google": "^1.2.18", + "@ai-sdk/google-vertex": "^2.2.22", "@ai-sdk/openai": "^1.3.21", "@ai-sdk/perplexity": "^1.1.6", "@apollo/server": "^4.11.3", diff --git a/packages/backend/server/src/__tests__/copilot.e2e.ts b/packages/backend/server/src/__tests__/copilot.e2e.ts index 44f35bf236..a526862782 100644 --- a/packages/backend/server/src/__tests__/copilot.e2e.ts +++ b/packages/backend/server/src/__tests__/copilot.e2e.ts @@ -19,7 +19,7 @@ import { MockEmbeddingClient } from '../plugins/copilot/context/embedding'; import { prompts, PromptService } from '../plugins/copilot/prompt'; import { CopilotProviderFactory, - GeminiProvider, + GeminiGenerativeProvider, OpenAIProvider, } from '../plugins/copilot/providers'; import { CopilotStorage } from '../plugins/copilot/storage'; @@ -100,7 +100,9 @@ test.before(async t => { }, }); m.overrideProvider(OpenAIProvider).useClass(MockCopilotProvider); - m.overrideProvider(GeminiProvider).useClass(MockCopilotProvider); + m.overrideProvider(GeminiGenerativeProvider).useClass( + MockCopilotProvider + ); }, }); @@ -935,8 +937,8 @@ test('should be able to transcript', async t => { const { id: workspaceId } = await createWorkspace(app); for (const [provider, func] of [ - [GeminiProvider, 'text'], - [GeminiProvider, 'structure'], + [GeminiGenerativeProvider, 'text'], + [GeminiGenerativeProvider, 'structure'], ] as const) { Sinon.stub(app.get(provider), func).resolves( JSON.stringify([ diff --git a/packages/backend/server/src/__tests__/mocks/copilot.mock.ts b/packages/backend/server/src/__tests__/mocks/copilot.mock.ts index 2f30df562f..52bf3a5d2d 100644 --- a/packages/backend/server/src/__tests__/mocks/copilot.mock.ts +++ b/packages/backend/server/src/__tests__/mocks/copilot.mock.ts @@ -111,7 +111,7 @@ export class MockCopilotProvider extends OpenAIProvider { ], }, { - id: 'gemini-2.5-pro-preview-05-06', + id: 'gemini-2.5-flash-preview-05-20', capabilities: [ { input: [ModelInputType.Text, ModelInputType.Image], diff --git a/packages/backend/server/src/plugins/copilot/config.ts b/packages/backend/server/src/plugins/copilot/config.ts index fb3f3c80ec..517d76c582 100644 --- a/packages/backend/server/src/plugins/copilot/config.ts +++ b/packages/backend/server/src/plugins/copilot/config.ts @@ -3,12 +3,15 @@ import { StorageJSONSchema, StorageProviderConfig, } from '../../base'; -import { AnthropicConfig } from './providers/anthropic'; +import { + AnthropicOfficialConfig, + AnthropicVertexConfig, +} from './providers/anthropic'; import type { FalConfig } from './providers/fal'; -import { GeminiConfig } from './providers/gemini'; +import { GeminiGenerativeConfig, GeminiVertexConfig } from './providers/gemini'; import { OpenAIConfig } from './providers/openai'; import { PerplexityConfig } from './providers/perplexity'; - +import { VertexSchema } from './providers/types'; declare global { interface AppConfigSchema { copilot: { @@ -23,9 +26,11 @@ declare global { providers: { openai: ConfigItem; fal: ConfigItem; - gemini: ConfigItem; + gemini: ConfigItem; + geminiVertex: ConfigItem; perplexity: ConfigItem; - anthropic: ConfigItem; + anthropic: ConfigItem; + anthropicVertex: ConfigItem; }; }; } @@ -55,6 +60,11 @@ defineModuleConfig('copilot', { apiKey: '', }, }, + 'providers.geminiVertex': { + desc: 'The config for the gemini provider in Google Vertex AI.', + default: {}, + schema: VertexSchema, + }, 'providers.perplexity': { desc: 'The config for the perplexity provider.', default: { @@ -67,6 +77,11 @@ defineModuleConfig('copilot', { apiKey: '', }, }, + 'providers.anthropicVertex': { + desc: 'The config for the anthropic provider in Google Vertex AI.', + default: {}, + schema: VertexSchema, + }, unsplash: { desc: 'The config for the unsplash key.', default: { diff --git a/packages/backend/server/src/plugins/copilot/prompt/prompts.ts b/packages/backend/server/src/plugins/copilot/prompt/prompts.ts index 3b334d8643..f5239ecbb3 100644 --- a/packages/backend/server/src/plugins/copilot/prompt/prompts.ts +++ b/packages/backend/server/src/plugins/copilot/prompt/prompts.ts @@ -350,7 +350,7 @@ const actions: Prompt[] = [ { name: 'Transcript audio', action: 'Transcript audio', - model: 'gemini-2.5-pro-preview-05-06', + model: 'gemini-2.5-flash-preview-05-20', messages: [ { role: 'system', @@ -1096,8 +1096,12 @@ const chat: Prompt[] = [ 'o4-mini', 'claude-3-7-sonnet-20250219', 'claude-3-5-sonnet-20241022', - 'gemini-2.5-flash-preview-04-17', + 'gemini-2.5-flash-preview-05-20', 'gemini-2.5-pro-preview-05-06', + 'claude-opus-4@20250514', + 'claude-sonnet-4@20250514', + 'claude-3-7-sonnet@20250219', + 'claude-3-5-sonnet@20240620', ], messages: [ { diff --git a/packages/backend/server/src/plugins/copilot/providers/anthropic.ts b/packages/backend/server/src/plugins/copilot/providers/anthropic/anthropic.ts similarity index 80% rename from packages/backend/server/src/plugins/copilot/providers/anthropic.ts rename to packages/backend/server/src/plugins/copilot/providers/anthropic/anthropic.ts index a94f440159..f979094c66 100644 --- a/packages/backend/server/src/plugins/copilot/providers/anthropic.ts +++ b/packages/backend/server/src/plugins/copilot/providers/anthropic/anthropic.ts @@ -1,71 +1,33 @@ import { - AnthropicProvider as AnthropicSDKProvider, - AnthropicProviderOptions, - createAnthropic, + type AnthropicProvider as AnthropicSDKProvider, + type AnthropicProviderOptions, } from '@ai-sdk/anthropic'; +import { type GoogleVertexAnthropicProvider } from '@ai-sdk/google-vertex/anthropic'; import { AISDKError, generateText, streamText } from 'ai'; import { CopilotProviderSideError, metrics, UserFriendlyError, -} from '../../../base'; -import { createExaCrawlTool, createExaSearchTool } from '../tools'; -import { CopilotProvider } from './provider'; +} from '../../../../base'; +import { createExaCrawlTool, createExaSearchTool } from '../../tools'; +import { CopilotProvider } from '../provider'; import type { CopilotChatOptions, ModelConditions, PromptMessage, -} from './types'; -import { CopilotProviderType, ModelInputType, ModelOutputType } from './types'; -import { chatToGPTMessage } from './utils'; - -export type AnthropicConfig = { - apiKey: string; - baseUrl?: string; -}; - -export class AnthropicProvider extends CopilotProvider { - override readonly type = CopilotProviderType.Anthropic; - override readonly models = [ - { - id: 'claude-3-7-sonnet-20250219', - capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ModelOutputType.Text], - defaultForOutputType: true, - }, - ], - }, - { - id: 'claude-3-5-sonnet-20241022', - capabilities: [ - { - input: [ModelInputType.Text, ModelInputType.Image], - output: [ModelOutputType.Text], - }, - ], - }, - ]; +} from '../types'; +import { ModelOutputType } from '../types'; +import { chatToGPTMessage } from '../utils'; +export abstract class AnthropicProvider extends CopilotProvider { private readonly MAX_STEPS = 20; private readonly CALLOUT_PREFIX = '\n> [!]\n> '; - #instance!: AnthropicSDKProvider; - - override configured(): boolean { - return !!this.config.apiKey; - } - - protected override setup() { - super.setup(); - this.#instance = createAnthropic({ - apiKey: this.config.apiKey, - baseURL: this.config.baseUrl, - }); - } + protected abstract instance: + | AnthropicSDKProvider + | GoogleVertexAnthropicProvider; private handleError(e: any) { if (e instanceof UserFriendlyError) { @@ -100,7 +62,7 @@ export class AnthropicProvider extends CopilotProvider { const [system, msgs] = await chatToGPTMessage(messages); - const modelInstance = this.#instance(model.id); + const modelInstance = this.instance(model.id); const { text, reasoning } = await generateText({ model: modelInstance, system, @@ -136,7 +98,7 @@ export class AnthropicProvider extends CopilotProvider { metrics.ai.counter('chat_text_stream_calls').add(1, { model: model.id }); const [system, msgs] = await chatToGPTMessage(messages); const { fullStream } = streamText({ - model: this.#instance(model.id), + model: this.instance(model.id), system, messages: msgs, abortSignal: options.signal, diff --git a/packages/backend/server/src/plugins/copilot/providers/anthropic/index.ts b/packages/backend/server/src/plugins/copilot/providers/anthropic/index.ts new file mode 100644 index 0000000000..f37327b774 --- /dev/null +++ b/packages/backend/server/src/plugins/copilot/providers/anthropic/index.ts @@ -0,0 +1,2 @@ +export * from './official'; +export * from './vertex'; diff --git a/packages/backend/server/src/plugins/copilot/providers/anthropic/official.ts b/packages/backend/server/src/plugins/copilot/providers/anthropic/official.ts new file mode 100644 index 0000000000..936c2feb3d --- /dev/null +++ b/packages/backend/server/src/plugins/copilot/providers/anthropic/official.ts @@ -0,0 +1,52 @@ +import { + type AnthropicProvider as AnthropicSDKProvider, + createAnthropic, +} from '@ai-sdk/anthropic'; + +import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types'; +import { AnthropicProvider } from './anthropic'; + +export type AnthropicOfficialConfig = { + apiKey: string; + baseUrl?: string; +}; + +export class AnthropicOfficialProvider extends AnthropicProvider { + override readonly type = CopilotProviderType.Anthropic; + + override readonly models = [ + { + id: 'claude-3-7-sonnet-20250219', + capabilities: [ + { + input: [ModelInputType.Text, ModelInputType.Image], + output: [ModelOutputType.Text], + defaultForOutputType: true, + }, + ], + }, + { + id: 'claude-3-5-sonnet-20241022', + capabilities: [ + { + input: [ModelInputType.Text, ModelInputType.Image], + output: [ModelOutputType.Text], + }, + ], + }, + ]; + + protected instance!: AnthropicSDKProvider; + + override configured(): boolean { + return !!this.config.apiKey; + } + + override setup() { + super.setup(); + this.instance = createAnthropic({ + apiKey: this.config.apiKey, + baseURL: this.config.baseUrl, + }); + } +} diff --git a/packages/backend/server/src/plugins/copilot/providers/anthropic/vertex.ts b/packages/backend/server/src/plugins/copilot/providers/anthropic/vertex.ts new file mode 100644 index 0000000000..355ab779a3 --- /dev/null +++ b/packages/backend/server/src/plugins/copilot/providers/anthropic/vertex.ts @@ -0,0 +1,65 @@ +import { + createVertexAnthropic, + type GoogleVertexAnthropicProvider, + type GoogleVertexAnthropicProviderSettings, +} from '@ai-sdk/google-vertex/anthropic'; + +import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types'; +import { AnthropicProvider } from './anthropic'; + +export type AnthropicVertexConfig = GoogleVertexAnthropicProviderSettings; + +export class AnthropicVertexProvider extends AnthropicProvider { + override readonly type = CopilotProviderType.AnthropicVertex; + + override readonly models = [ + { + id: 'claude-opus-4@20250514', + capabilities: [ + { + input: [ModelInputType.Text, ModelInputType.Image], + output: [ModelOutputType.Text], + }, + ], + }, + { + id: 'claude-sonnet-4@20250514', + capabilities: [ + { + input: [ModelInputType.Text, ModelInputType.Image], + output: [ModelOutputType.Text], + }, + ], + }, + { + id: 'claude-3-7-sonnet@20250219', + capabilities: [ + { + input: [ModelInputType.Text, ModelInputType.Image], + output: [ModelOutputType.Text], + }, + ], + }, + { + id: 'claude-3-5-sonnet@20240620', + capabilities: [ + { + input: [ModelInputType.Text, ModelInputType.Image], + output: [ModelOutputType.Text], + defaultForOutputType: true, + }, + ], + }, + ]; + + protected instance!: GoogleVertexAnthropicProvider; + + override configured(): boolean { + return !!this.config.location && !!this.config.googleAuthOptions; + } + + override setup() { + super.setup(); + this.instance = createVertexAnthropic(this.config); + } +} diff --git a/packages/backend/server/src/plugins/copilot/providers/gemini.ts b/packages/backend/server/src/plugins/copilot/providers/gemini/gemini.ts similarity index 73% rename from packages/backend/server/src/plugins/copilot/providers/gemini.ts rename to packages/backend/server/src/plugins/copilot/providers/gemini/gemini.ts index 5604363418..42c78caaf6 100644 --- a/packages/backend/server/src/plugins/copilot/providers/gemini.ts +++ b/packages/backend/server/src/plugins/copilot/providers/gemini/gemini.ts @@ -1,8 +1,8 @@ -import { - createGoogleGenerativeAI, - type GoogleGenerativeAIProvider, - type GoogleGenerativeAIProviderOptions, +import type { + GoogleGenerativeAIProvider, + GoogleGenerativeAIProviderOptions, } from '@ai-sdk/google'; +import type { GoogleVertexProvider } from '@ai-sdk/google-vertex'; import { AISDKError, generateObject, @@ -16,16 +16,16 @@ import { CopilotProviderSideError, metrics, UserFriendlyError, -} from '../../../base'; -import { CopilotProvider } from './provider'; +} from '../../../../base'; +import { CopilotProvider } from '../provider'; import type { CopilotChatOptions, CopilotImageOptions, ModelConditions, PromptMessage, -} from './types'; -import { CopilotProviderType, ModelInputType, ModelOutputType } from './types'; -import { chatToGPTMessage } from './utils'; +} from '../types'; +import { ModelOutputType } from '../types'; +import { chatToGPTMessage } from '../utils'; export const DEFAULT_DIMENSIONS = 256; @@ -34,82 +34,14 @@ export type GeminiConfig = { baseUrl?: string; }; -export class GeminiProvider extends CopilotProvider { - override readonly type = CopilotProviderType.Gemini; - - readonly models = [ - { - name: 'Gemini 2.0 Flash', - id: 'gemini-2.0-flash-001', - capabilities: [ - { - input: [ - ModelInputType.Text, - ModelInputType.Image, - ModelInputType.Audio, - ], - output: [ModelOutputType.Text, ModelOutputType.Structured], - defaultForOutputType: true, - }, - ], - }, - { - name: 'Gemini 2.5 Flash', - id: 'gemini-2.5-flash-preview-04-17', - capabilities: [ - { - input: [ - ModelInputType.Text, - ModelInputType.Image, - ModelInputType.Audio, - ], - output: [ModelOutputType.Text, ModelOutputType.Structured], - }, - ], - }, - { - name: 'Gemini 2.5 Pro', - id: 'gemini-2.5-pro-preview-05-06', - capabilities: [ - { - input: [ - ModelInputType.Text, - ModelInputType.Image, - ModelInputType.Audio, - ], - output: [ModelOutputType.Text, ModelOutputType.Structured], - }, - ], - }, - { - name: 'Text Embedding 004', - id: 'text-embedding-004', - capabilities: [ - { - input: [ModelInputType.Text], - output: [ModelOutputType.Embedding], - }, - ], - }, - ]; - +export abstract class GeminiProvider extends CopilotProvider { private readonly MAX_STEPS = 20; private readonly CALLOUT_PREFIX = '\n> [!]\n> '; - #instance!: GoogleGenerativeAIProvider; - - override configured(): boolean { - return !!this.config.apiKey; - } - - protected override setup() { - super.setup(); - this.#instance = createGoogleGenerativeAI({ - apiKey: this.config.apiKey, - baseURL: this.config.baseUrl, - }); - } + protected abstract instance: + | GoogleGenerativeAIProvider + | GoogleVertexProvider; private handleError(e: any) { if (e instanceof UserFriendlyError) { @@ -130,7 +62,7 @@ export class GeminiProvider extends CopilotProvider { } } - override async text( + async text( cond: ModelConditions, messages: PromptMessage[], options: CopilotChatOptions = {} @@ -144,7 +76,7 @@ export class GeminiProvider extends CopilotProvider { const [system, msgs] = await chatToGPTMessage(messages); - const modelInstance = this.#instance(model.id); + const modelInstance = this.instance(model.id); const { text } = await generateText({ model: modelInstance, system, @@ -177,7 +109,7 @@ export class GeminiProvider extends CopilotProvider { throw new CopilotPromptInvalid('Schema is required'); } - const modelInstance = this.#instance(model.id, { + const modelInstance = this.instance(model.id, { structuredOutputs: true, }); const { object } = await generateObject({ @@ -209,7 +141,7 @@ export class GeminiProvider extends CopilotProvider { } } - override async *streamText( + async *streamText( cond: ModelConditions, messages: PromptMessage[], options: CopilotChatOptions | CopilotImageOptions = {} @@ -223,7 +155,7 @@ export class GeminiProvider extends CopilotProvider { const [system, msgs] = await chatToGPTMessage(messages); const { fullStream } = streamText({ - model: this.#instance(model.id, { + model: this.instance(model.id, { useSearchGrounding: this.useSearchGrounding(options), }), system, diff --git a/packages/backend/server/src/plugins/copilot/providers/gemini/generative.ts b/packages/backend/server/src/plugins/copilot/providers/gemini/generative.ts new file mode 100644 index 0000000000..26bde7f9fd --- /dev/null +++ b/packages/backend/server/src/plugins/copilot/providers/gemini/generative.ts @@ -0,0 +1,86 @@ +import { + createGoogleGenerativeAI, + type GoogleGenerativeAIProvider, +} from '@ai-sdk/google'; + +import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types'; +import { GeminiProvider } from './gemini'; + +export type GeminiGenerativeConfig = { + apiKey: string; + baseUrl?: string; +}; + +export class GeminiGenerativeProvider extends GeminiProvider { + override readonly type = CopilotProviderType.Gemini; + + readonly models = [ + { + name: 'Gemini 2.0 Flash', + id: 'gemini-2.0-flash-001', + capabilities: [ + { + input: [ + ModelInputType.Text, + ModelInputType.Image, + ModelInputType.Audio, + ], + output: [ModelOutputType.Text, ModelOutputType.Structured], + defaultForOutputType: true, + }, + ], + }, + { + name: 'Gemini 2.5 Flash', + id: 'gemini-2.5-flash-preview-05-20', + capabilities: [ + { + input: [ + ModelInputType.Text, + ModelInputType.Image, + ModelInputType.Audio, + ], + output: [ModelOutputType.Text, ModelOutputType.Structured], + }, + ], + }, + { + name: 'Gemini 2.5 Pro', + id: 'gemini-2.5-pro-preview-05-06', + capabilities: [ + { + input: [ + ModelInputType.Text, + ModelInputType.Image, + ModelInputType.Audio, + ], + output: [ModelOutputType.Text, ModelOutputType.Structured], + }, + ], + }, + { + name: 'Text Embedding 004', + id: 'text-embedding-004', + capabilities: [ + { + input: [ModelInputType.Text], + output: [ModelOutputType.Embedding], + }, + ], + }, + ]; + + protected instance!: GoogleGenerativeAIProvider; + + override configured(): boolean { + return !!this.config.apiKey; + } + + protected override setup() { + super.setup(); + this.instance = createGoogleGenerativeAI({ + apiKey: this.config.apiKey, + baseURL: this.config.baseUrl, + }); + } +} diff --git a/packages/backend/server/src/plugins/copilot/providers/gemini/index.ts b/packages/backend/server/src/plugins/copilot/providers/gemini/index.ts new file mode 100644 index 0000000000..fbacc7fdc2 --- /dev/null +++ b/packages/backend/server/src/plugins/copilot/providers/gemini/index.ts @@ -0,0 +1,2 @@ +export * from './generative'; +export * from './vertex'; diff --git a/packages/backend/server/src/plugins/copilot/providers/gemini/vertex.ts b/packages/backend/server/src/plugins/copilot/providers/gemini/vertex.ts new file mode 100644 index 0000000000..06b89f8e6f --- /dev/null +++ b/packages/backend/server/src/plugins/copilot/providers/gemini/vertex.ts @@ -0,0 +1,56 @@ +import { + createVertex, + type GoogleVertexProvider, + type GoogleVertexProviderSettings, +} from '@ai-sdk/google-vertex'; + +import { CopilotProviderType, ModelInputType, ModelOutputType } from '../types'; +import { GeminiProvider } from './gemini'; + +export type GeminiVertexConfig = GoogleVertexProviderSettings; + +export class GeminiVertexProvider extends GeminiProvider { + override readonly type = CopilotProviderType.GeminiVertex; + + readonly models = [ + { + name: 'Gemini 2.5 Flash', + id: 'gemini-2.5-flash-preview-05-20', + capabilities: [ + { + input: [ + ModelInputType.Text, + ModelInputType.Image, + ModelInputType.Audio, + ], + output: [ModelOutputType.Text, ModelOutputType.Structured], + }, + ], + }, + { + name: 'Gemini 2.5 Pro', + id: 'gemini-2.5-pro-preview-05-06', + capabilities: [ + { + input: [ + ModelInputType.Text, + ModelInputType.Image, + ModelInputType.Audio, + ], + output: [ModelOutputType.Text, ModelOutputType.Structured], + }, + ], + }, + ]; + + protected instance!: GoogleVertexProvider; + + override configured(): boolean { + return !!this.config.location && !!this.config.googleAuthOptions; + } + + protected override setup() { + super.setup(); + this.instance = createVertex(this.config); + } +} diff --git a/packages/backend/server/src/plugins/copilot/providers/index.ts b/packages/backend/server/src/plugins/copilot/providers/index.ts index 942ce62c1b..fbaf1d5fe4 100644 --- a/packages/backend/server/src/plugins/copilot/providers/index.ts +++ b/packages/backend/server/src/plugins/copilot/providers/index.ts @@ -1,21 +1,29 @@ -import { AnthropicProvider } from './anthropic'; +import { + AnthropicOfficialProvider, + AnthropicVertexProvider, +} from './anthropic'; import { FalProvider } from './fal'; -import { GeminiProvider } from './gemini'; +import { GeminiGenerativeProvider, GeminiVertexProvider } from './gemini'; import { OpenAIProvider } from './openai'; import { PerplexityProvider } from './perplexity'; export const CopilotProviders = [ OpenAIProvider, FalProvider, - GeminiProvider, + GeminiGenerativeProvider, + GeminiVertexProvider, PerplexityProvider, - AnthropicProvider, + AnthropicOfficialProvider, + AnthropicVertexProvider, ]; -export { AnthropicProvider } from './anthropic'; +export { + AnthropicOfficialProvider, + AnthropicVertexProvider, +} from './anthropic'; export { CopilotProviderFactory } from './factory'; export { FalProvider } from './fal'; -export { GeminiProvider } from './gemini'; +export { GeminiGenerativeProvider, GeminiVertexProvider } from './gemini'; export { OpenAIProvider } from './openai'; export { PerplexityProvider } from './perplexity'; export type { CopilotProvider } from './provider'; diff --git a/packages/backend/server/src/plugins/copilot/providers/types.ts b/packages/backend/server/src/plugins/copilot/providers/types.ts index e20f3c5d20..bb29e02001 100644 --- a/packages/backend/server/src/plugins/copilot/providers/types.ts +++ b/packages/backend/server/src/plugins/copilot/providers/types.ts @@ -1,12 +1,16 @@ import { AiPromptRole } from '@prisma/client'; import { z } from 'zod'; +import { JSONSchema } from '../../../base'; + // ========== provider ========== export enum CopilotProviderType { Anthropic = 'anthropic', + AnthropicVertex = 'anthropicVertex', FAL = 'fal', Gemini = 'gemini', + GeminiVertex = 'geminiVertex', OpenAI = 'openai', Perplexity = 'perplexity', } @@ -15,6 +19,41 @@ export const CopilotProviderSchema = z.object({ type: z.nativeEnum(CopilotProviderType), }); +export const VertexSchema: JSONSchema = { + type: 'object', + description: 'The config for the google vertex provider.', + properties: { + location: { + type: 'string', + description: 'The location of the google vertex provider.', + }, + project: { + type: 'string', + description: 'The project name of the google vertex provider.', + }, + googleAuthOptions: { + type: 'object', + description: 'The google auth options for the google vertex provider.', + properties: { + credentials: { + type: 'object', + description: 'The credentials for the google vertex provider.', + properties: { + client_email: { + type: 'string', + description: 'The client email for the google vertex provider.', + }, + private_key: { + type: 'string', + description: 'The private key for the google vertex provider.', + }, + }, + }, + }, + }, + }, +}; + // ========== prompt ========== export const PromptConfigStrictSchema = z.object({ diff --git a/packages/frontend/admin/src/config.json b/packages/frontend/admin/src/config.json index dc65f58a85..ca9dd02293 100644 --- a/packages/frontend/admin/src/config.json +++ b/packages/frontend/admin/src/config.json @@ -233,6 +233,10 @@ "type": "Object", "desc": "The config for the gemini provider." }, + "providers.geminiVertex": { + "type": "Object", + "desc": "The config for the gemini provider in Google Vertex AI." + }, "providers.perplexity": { "type": "Object", "desc": "The config for the perplexity provider." @@ -241,6 +245,10 @@ "type": "Object", "desc": "The config for the anthropic provider." }, + "providers.anthropicVertex": { + "type": "Object", + "desc": "The config for the anthropic provider in Google Vertex AI." + }, "unsplash": { "type": "Object", "desc": "The config for the unsplash key." diff --git a/yarn.lock b/yarn.lock index 2d04cdc353..3786152aa6 100644 --- a/yarn.lock +++ b/yarn.lock @@ -909,6 +909,7 @@ __metadata: "@affine/server-native": "workspace:*" "@ai-sdk/anthropic": "npm:^1.2.10" "@ai-sdk/google": "npm:^1.2.18" + "@ai-sdk/google-vertex": "npm:^2.2.22" "@ai-sdk/openai": "npm:^1.3.21" "@ai-sdk/perplexity": "npm:^1.1.6" "@apollo/server": "npm:^4.11.3" @@ -1073,7 +1074,7 @@ __metadata: languageName: unknown linkType: soft -"@ai-sdk/anthropic@npm:^1.2.10": +"@ai-sdk/anthropic@npm:1.2.11, @ai-sdk/anthropic@npm:^1.2.10": version: 1.2.11 resolution: "@ai-sdk/anthropic@npm:1.2.11" dependencies: @@ -1085,7 +1086,22 @@ __metadata: languageName: node linkType: hard -"@ai-sdk/google@npm:^1.2.18": +"@ai-sdk/google-vertex@npm:^2.2.22": + version: 2.2.22 + resolution: "@ai-sdk/google-vertex@npm:2.2.22" + dependencies: + "@ai-sdk/anthropic": "npm:1.2.11" + "@ai-sdk/google": "npm:1.2.18" + "@ai-sdk/provider": "npm:1.1.3" + "@ai-sdk/provider-utils": "npm:2.2.8" + google-auth-library: "npm:^9.15.0" + peerDependencies: + zod: ^3.0.0 + checksum: 10/852928d35797cc14802d8c4a7dba2794197f019507a24124e7f87e456ba19f2b4e50438af626dd30e3a2a0f9f10819d6534a31e23c08f7f3d148d56b68167d0a + languageName: node + linkType: hard + +"@ai-sdk/google@npm:1.2.18, @ai-sdk/google@npm:^1.2.18": version: 1.2.18 resolution: "@ai-sdk/google@npm:1.2.18" dependencies: @@ -22796,7 +22812,7 @@ __metadata: languageName: node linkType: hard -"google-auth-library@npm:^9.0.0, google-auth-library@npm:^9.7.0": +"google-auth-library@npm:^9.0.0, google-auth-library@npm:^9.15.0, google-auth-library@npm:^9.7.0": version: 9.15.1 resolution: "google-auth-library@npm:9.15.1" dependencies: