diff --git a/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.md b/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.md index 448c65cc1e..9c1693aefe 100644 --- a/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.md +++ b/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.md @@ -69,10 +69,10 @@ Generated by [AVA](https://avajs.dev). [ { - actions: '[{"a":"A","s":30,"e":45,"t":"Hello, everyone."},{"a":"B","s":46,"e":70,"t":"Hi, thank you for joining the meeting today."}]', + actions: 'generate text to text', status: 'claimed', - summary: '[{"a":"A","s":30,"e":45,"t":"Hello, everyone."},{"a":"B","s":46,"e":70,"t":"Hi, thank you for joining the meeting today."}]', - title: '[{"a":"A","s":30,"e":45,"t":"Hello, everyone."},{"a":"B","s":46,"e":70,"t":"Hi, thank you for joining the meeting today."}]', + summary: 'generate text to text', + title: 'generate text to text', transcription: [ { end: '00:00:45', @@ -102,10 +102,10 @@ Generated by [AVA](https://avajs.dev). [ { - actions: '[{"a":"A","s":30,"e":45,"t":"Hello, everyone."},{"a":"B","s":46,"e":70,"t":"Hi, thank you for joining the meeting today."}]', + actions: 'generate text to text', status: 'claimed', - summary: '[{"a":"A","s":30,"e":45,"t":"Hello, everyone."},{"a":"B","s":46,"e":70,"t":"Hi, thank you for joining the meeting today."}]', - title: '[{"a":"A","s":30,"e":45,"t":"Hello, everyone."},{"a":"B","s":46,"e":70,"t":"Hi, thank you for joining the meeting today."}]', + summary: 'generate text to text', + title: 'generate text to text', transcription: [ { end: '00:00:45', diff --git a/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.snap b/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.snap index 53e83ab698..3bd449d93d 100644 Binary files a/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.snap and b/packages/backend/server/src/__tests__/__snapshots__/copilot.e2e.ts.snap differ diff --git a/packages/backend/server/src/__tests__/copilot-provider.spec.ts b/packages/backend/server/src/__tests__/copilot-provider.spec.ts index 3281c64f40..6ca04e26d7 100644 --- a/packages/backend/server/src/__tests__/copilot-provider.spec.ts +++ b/packages/backend/server/src/__tests__/copilot-provider.spec.ts @@ -6,7 +6,10 @@ import { AuthService } from '../core/auth'; import { QuotaModule } from '../core/quota'; import { CopilotModule } from '../plugins/copilot'; import { prompts, PromptService } from '../plugins/copilot/prompt'; -import { CopilotProviderFactory } from '../plugins/copilot/providers'; +import { + CopilotProviderFactory, + CopilotProviderType, +} from '../plugins/copilot/providers'; import { TranscriptionResponseSchema } from '../plugins/copilot/transcript/types'; import { CopilotChatTextExecutor, @@ -183,11 +186,18 @@ const checkUrl = (url: string) => { const retry = async ( action: string, t: ExecutionContext, - callback: (t: ExecutionContext) => void + callback: (t: ExecutionContext) => Promise ) => { let i = 3; while (i--) { - const ret = await t.try(callback); + const ret = await t.try(async t => { + try { + await callback(t); + } catch (e) { + t.log(`Error during ${action}:`, e); + throw e; + } + }); if (ret.passed) { return ret.commit(); } else { @@ -343,6 +353,7 @@ const actions = [ }); }, type: 'structured' as const, + prefer: CopilotProviderType.Gemini, }, { name: 'Should transcribe middle audio', @@ -365,6 +376,7 @@ const actions = [ }); }, type: 'structured' as const, + prefer: CopilotProviderType.Gemini, }, { name: 'Should transcribe long audio', @@ -387,6 +399,7 @@ const actions = [ }); }, type: 'structured' as const, + prefer: CopilotProviderType.Gemini, }, { promptName: [ @@ -554,7 +567,15 @@ const actions = [ }, ]; -for (const { name, promptName, messages, verifier, type, config } of actions) { +for (const { + name, + promptName, + messages, + verifier, + type, + config, + prefer, +} of actions) { const prompts = Array.isArray(promptName) ? promptName : [promptName]; for (const promptName of prompts) { test( @@ -564,7 +585,9 @@ for (const { name, promptName, messages, verifier, type, config } of actions) { const { factory, prompt: promptService } = t.context; const prompt = (await promptService.get(promptName))!; t.truthy(prompt, 'should have prompt'); - const provider = (await factory.getProviderByModel(prompt.model))!; + const provider = (await factory.getProviderByModel(prompt.model, { + prefer, + }))!; t.truthy(provider, 'should have provider'); await retry(`action: ${promptName}`, t, async t => { switch (type) { diff --git a/packages/backend/server/src/__tests__/copilot.e2e.ts b/packages/backend/server/src/__tests__/copilot.e2e.ts index a526862782..68cb33541a 100644 --- a/packages/backend/server/src/__tests__/copilot.e2e.ts +++ b/packages/backend/server/src/__tests__/copilot.e2e.ts @@ -19,6 +19,7 @@ import { MockEmbeddingClient } from '../plugins/copilot/context/embedding'; import { prompts, PromptService } from '../plugins/copilot/prompt'; import { CopilotProviderFactory, + CopilotProviderType, GeminiGenerativeProvider, OpenAIProvider, } from '../plugins/copilot/providers'; @@ -79,7 +80,7 @@ test.before(async t => { providers: { openai: { apiKey: '1' }, fal: {}, - perplexity: {}, + gemini: { apiKey: '1' }, }, unsplash: { key: process.env.UNSPLASH_ACCESS_KEY || '1', @@ -101,7 +102,10 @@ test.before(async t => { }); m.overrideProvider(OpenAIProvider).useClass(MockCopilotProvider); m.overrideProvider(GeminiGenerativeProvider).useClass( - MockCopilotProvider + class MockGenerativeProvider extends MockCopilotProvider { + // @ts-expect-error + override type: CopilotProviderType = CopilotProviderType.Gemini; + } ); }, }); diff --git a/packages/backend/server/src/plugins/copilot/transcript/service.ts b/packages/backend/server/src/plugins/copilot/transcript/service.ts index 67ff4f8090..885bfac5e1 100644 --- a/packages/backend/server/src/plugins/copilot/transcript/service.ts +++ b/packages/backend/server/src/plugins/copilot/transcript/service.ts @@ -18,6 +18,7 @@ import { PromptService } from '../prompt'; import { CopilotProvider, CopilotProviderFactory, + CopilotProviderType, ModelOutputType, PromptMessage, } from '../providers'; @@ -156,14 +157,18 @@ export class CopilotTranscriptionService { private async getProvider( modelId: string, - structured: boolean + structured: boolean, + prefer?: CopilotProviderType ): Promise { - let provider = await this.providerFactory.getProvider({ - outputType: structured - ? ModelOutputType.Structured - : ModelOutputType.Text, - modelId, - }); + let provider = await this.providerFactory.getProvider( + { + outputType: structured + ? ModelOutputType.Structured + : ModelOutputType.Text, + modelId, + }, + { prefer } + ); if (!provider) { throw new NoCopilotProviderAvailable(); @@ -175,7 +180,8 @@ export class CopilotTranscriptionService { private async chatWithPrompt( promptName: string, message: Partial, - schema?: ZodType + schema?: ZodType, + prefer?: CopilotProviderType ): Promise { const prompt = await this.prompt.get(promptName); if (!prompt) { @@ -186,7 +192,7 @@ export class CopilotTranscriptionService { const msg = { role: 'user' as const, content: '', ...message }; const config = Object.assign({}, prompt.config); if (schema) { - const provider = await this.getProvider(prompt.model, true); + const provider = await this.getProvider(prompt.model, true, prefer); return provider.structure( cond, [...prompt.finish({ schema }), msg], @@ -226,13 +232,12 @@ export class CopilotTranscriptionService { } private async callTranscript(url: string, mimeType: string, offset: number) { + // NOTE: Vertex provider not support transcription yet, we always use Gemini here const result = await this.chatWithPrompt( 'Transcript audio', - { - attachments: [url], - params: { mimetype: mimeType }, - }, - TranscriptionResponseSchema + { attachments: [url], params: { mimetype: mimeType } }, + TranscriptionResponseSchema, + CopilotProviderType.Gemini ); const transcription = TranscriptionResponseSchema.parse(