feat(server): init gemini provider & transcript action (#10731)

Author: darkskygit
Date: 2025-03-18 04:28:18 +00:00
Parent: a016630a82
Commit: e09b5fee12
14 changed files with 557 additions and 1 deletion

View File

@@ -1,5 +1,6 @@
import type { ExecutionContext, TestFn } from 'ava';
import ava from 'ava';
import { z } from 'zod';
import { ConfigModule } from '../base/config';
import { AuthService } from '../core/auth';
@@ -269,6 +270,36 @@ test('should validate markdown list', t => {
// ==================== action ====================
const actions = [
{
promptName: ['Transcript audio'],
messages: [
{
role: 'user' as const,
content: '',
attachments: [
'https://cdn.affine.pro/copilot-test/MP9qDGuYgnY+ILoEAmHpp3h9Npuw2403EAYMEA.mp3',
],
},
],
verifier: (t: ExecutionContext<Tester>, result: string) => {
// strip any markdown code fence wrapping the JSON
const cleaned = result
.replace(/```[\w\s]+\n/g, '')
.replace(/\n```/g, '')
.trim();
t.notThrows(() => {
z.object({
speaker: z.string(),
start: z.string(),
end: z.string(),
transcription: z.string(),
})
.array()
.parse(JSON.parse(cleaned));
});
},
type: 'text' as const,
},
{
promptName: [
'Summary',
@@ -401,6 +432,7 @@ const actions = [
type: 'image' as const,
},
];
for (const { promptName, messages, verifier, type } of actions) {
const prompts = Array.isArray(promptName) ? promptName : [promptName];
for (const promptName of prompts) {
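For orientation, this is the kind of model output the verifier above accepts once the fences are stripped (the sample below is illustrative, not a recorded response):

// Illustrative sample, not a real Gemini response:
const sample =
  '```json\n[{"speaker":"A","start":"00:00:00","end":"00:00:05","transcription":"Hello."}]\n```';
const cleaned = sample
  .replace(/```[\w\s]+\n/g, '')
  .replace(/\n```/g, '')
  .trim();
// JSON.parse(cleaned) now yields an array matching the zod schema above.
JSON.parse(cleaned);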

View File

@@ -28,6 +28,7 @@ AFFiNE.ENV_MAP = {
CAPTCHA_TURNSTILE_SECRET: ['plugins.captcha.turnstile.secret', 'string'],
COPILOT_OPENAI_API_KEY: 'plugins.copilot.openai.apiKey',
COPILOT_FAL_API_KEY: 'plugins.copilot.fal.apiKey',
COPILOT_GOOGLE_API_KEY: 'plugins.copilot.google.apiKey',
COPILOT_PERPLEXITY_API_KEY: 'plugins.copilot.perplexity.apiKey',
COPILOT_UNSPLASH_API_KEY: 'plugins.copilot.unsplashKey',
REDIS_SERVER_HOST: 'redis.host',
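The new entry only maps an environment variable onto a dotted config path. A minimal sketch of what that mapping implies at startup (applyEnv is a hypothetical helper for illustration; the server's real loader may differ):

type EnvMapEntry = string | [string, string];

// Hypothetical: copy an environment variable into a nested config object
// at the dotted path named by the ENV_MAP entry.
function applyEnv(
  config: Record<string, any>,
  envKey: string,
  entry: EnvMapEntry
) {
  const path = Array.isArray(entry) ? entry[0] : entry;
  const value = process.env[envKey];
  if (value === undefined) return;
  const keys = path.split('.');
  let node = config;
  for (const key of keys.slice(0, -1)) {
    node = node[key] ??= {};
  }
  node[keys[keys.length - 1]] = value;
}

const config: Record<string, any> = {};
applyEnv(config, 'COPILOT_GOOGLE_API_KEY', 'plugins.copilot.google.apiKey');
// config.plugins.copilot.google.apiKey now holds the key GoogleProvider reads.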

View File

@@ -3,11 +3,13 @@ import type { ClientOptions as OpenAIClientOptions } from 'openai';
import { defineStartupConfig, ModuleConfig } from '../../base/config';
import { StorageConfig } from '../../base/storage/config';
import type { FalConfig } from './providers/fal';
import { GoogleConfig } from './providers/google';
import { PerplexityConfig } from './providers/perplexity';
export interface CopilotStartupConfigurations {
openai?: OpenAIClientOptions;
fal?: FalConfig;
google?: GoogleConfig;
perplexity?: PerplexityConfig;
test?: never;
unsplashKey?: string;
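With the google field added, a deployment can hand the Gemini key to the copilot plugin next to the existing providers. A minimal sketch, assuming this object is passed through defineStartupConfig the same way the other fields are (that wiring sits outside this diff):

// Sketch only; GoogleConfig requires apiKey, baseUrl is optional.
const copilot: CopilotStartupConfigurations = {
  openai: { apiKey: process.env.COPILOT_OPENAI_API_KEY },
  google: { apiKey: process.env.COPILOT_GOOGLE_API_KEY ?? '' },
};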

View File

@@ -23,6 +23,7 @@ import {
PerplexityProvider,
registerCopilotProvider,
} from './providers';
import { GoogleProvider } from './providers/google';
import {
CopilotResolver,
PromptsManagementResolver,
@@ -34,6 +35,7 @@ import { CopilotWorkflowExecutors, CopilotWorkflowService } from './workflow';
registerCopilotProvider(FalProvider);
registerCopilotProvider(OpenAIProvider);
registerCopilotProvider(GoogleProvider);
registerCopilotProvider(PerplexityProvider);
@Plugin({

View File

@@ -317,6 +317,62 @@ const actions: Prompt[] = [
model: 'face-to-sticker',
messages: [],
},
{
name: 'Transcript audio',
action: 'Transcript audio',
model: 'gemini-2.0-flash-001',
messages: [
{
role: 'system',
content: `
Convert a multi-speaker audio recording into a structured JSON format by transcribing the speech and identifying individual speakers.
1. Analyze the audio to detect multiple speakers by their distinct voices or audio channels.
2. Transcribe the audio content for each speaker and note the time intervals of speech.
# Output Format
The output should be a JSON array, with each element containing:
- "speaker": A label identifying the speaker, such as "A", "B", etc.
- "start": The start time of the transcribed segment in the format "HH:MM:SS".
- "end": The end time of the transcribed segment in the format "HH:MM:SS".
- "transcription": The transcribed text for the speaker's segment.
# Examples
**Example Input:**
- A multi-speaker audio file
**Example Output:**
[
  {
    "speaker": "A",
    "start": "00:00:30",
    "end": "00:00:45",
    "transcription": "Hello, everyone."
  },
  {
    "speaker": "B",
    "start": "00:00:46",
    "end": "00:01:10",
    "transcription": "Hi, thank you for joining the meeting today."
  }
]
# Notes
- Ensure accurate speaker differentiation, even when speakers overlap slightly or switch rapidly.
- Maintain a consistent speaker labeling system throughout the transcription.
`,
},
],
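// These flags are read by GoogleProvider.generateText and forwarded to the
// model as audioTimestamp / structuredOutputs (see the provider below).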
config: {
audioTimestamp: true,
jsonMode: true,
},
},
{
name: 'Generate a caption',
action: 'Generate a caption',

View File

@@ -0,0 +1,292 @@
import {
createGoogleGenerativeAI,
type GoogleGenerativeAIProvider,
} from '@ai-sdk/google';
import { Logger } from '@nestjs/common';
import {
AISDKError,
type CoreAssistantMessage,
type CoreUserMessage,
FilePart,
generateText,
streamText,
TextPart,
} from 'ai';
import {
CopilotPromptInvalid,
CopilotProviderSideError,
metrics,
UserFriendlyError,
} from '../../../base';
import {
ChatMessageRole,
CopilotCapability,
CopilotChatOptions,
CopilotProviderType,
CopilotTextToTextProvider,
PromptMessage,
} from '../types';
export const DEFAULT_DIMENSIONS = 256;
const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/;
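// Attachment MIME types inferred from file extensions. Several video
// entries (video/mov, video/avi, video/wmv, video/flv) are not the standard
// IANA registrations, but they roughly match the formats Google's Gemini
// documentation lists as accepted.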
const FORMAT_INFER_MAP: Record<string, string> = {
pdf: 'application/pdf',
mp3: 'audio/mpeg',
wav: 'audio/wav',
png: 'image/png',
jpeg: 'image/jpeg',
jpg: 'image/jpeg',
webp: 'image/webp',
txt: 'text/plain',
md: 'text/plain',
mov: 'video/mov',
mpeg: 'video/mpeg',
mp4: 'video/mp4',
avi: 'video/avi',
wmv: 'video/wmv',
flv: 'video/flv',
};
export type GoogleConfig = {
apiKey: string;
baseUrl?: string;
};
type ChatMessage = CoreUserMessage | CoreAssistantMessage;
export class GoogleProvider implements CopilotTextToTextProvider {
static readonly type = CopilotProviderType.Google;
static readonly capabilities = [CopilotCapability.TextToText];
readonly availableModels = [
// text to text
'gemini-2.0-flash-001',
// embeddings
'text-embedding-004',
];
private readonly logger = new Logger(GoogleProvider.name);
private readonly instance: GoogleGenerativeAIProvider;
constructor(config: GoogleConfig) {
this.instance = createGoogleGenerativeAI(config);
}
static assetsConfig(config: GoogleConfig) {
return !!config?.apiKey;
}
get type(): CopilotProviderType {
return GoogleProvider.type;
}
getCapabilities(): CopilotCapability[] {
return GoogleProvider.capabilities;
}
async isModelAvailable(model: string): Promise<boolean> {
return this.availableModels.includes(model);
}
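// Best-effort MIME inference: data: URLs carry their own type; other URLs
// fall back to an extension lookup. If nothing matches, the caller drops
// the attachment instead of sending an untyped file.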
private inferMimeType(url: string) {
if (url.startsWith('data:')) {
return url.split(';')[0].split(':')[1];
}
const extension = url.split('.').pop();
if (extension) {
return FORMAT_INFER_MAP[extension];
}
return undefined;
}
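// Split an optional leading system prompt from the conversation and convert
// the remaining messages into AI SDK CoreMessages, turning attachment URLs
// into file parts with an explicit (params.mimetype) or inferred MIME type.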
protected chatToGPTMessage(
messages: PromptMessage[]
): [string | undefined, ChatMessage[]] {
// read an optional leading system prompt without mutating the input array
const system =
  messages[0]?.role === 'system' ? messages[0].content : undefined;
// filter out system messages and redundant fields
const msgs = messages
.filter(m => m.role !== 'system')
.map(({ role, content, attachments, params }) => {
content = content.trim();
role = role as 'user' | 'assistant';
const mimetype = params?.mimetype;
if (Array.isArray(attachments)) {
const contents: (TextPart | FilePart)[] = [];
if (content.length) {
contents.push({
type: 'text',
text: content,
});
}
contents.push(
...attachments
.map(url => {
if (SIMPLE_IMAGE_URL_REGEX.test(url)) {
const mimeType =
typeof mimetype === 'string'
? mimetype
: this.inferMimeType(url);
if (mimeType) {
const data = url.startsWith('data:') ? url : new URL(url);
return {
type: 'file' as const,
data,
mimeType,
};
}
}
return undefined;
})
.filter(c => !!c)
);
return { role, content: contents } as ChatMessage;
} else {
return { role, content } as ChatMessage;
}
});
return [system, msgs];
}
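// Validate a request before touching the SDK: the model must be listed in
// availableModels, each message needs a valid role plus non-empty content or
// attachments, and jsonMode prompts must mention 'json' somewhere.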
protected async checkParams({
messages,
embeddings,
model,
options = {},
}: {
messages?: PromptMessage[];
embeddings?: string[];
model: string;
options: CopilotChatOptions;
}) {
if (!(await this.isModelAvailable(model))) {
throw new CopilotPromptInvalid(`Invalid model: ${model}`);
}
if (Array.isArray(messages) && messages.length > 0) {
if (
messages.some(
m =>
// check non-object
typeof m !== 'object' ||
!m ||
// check content
typeof m.content !== 'string' ||
// at least one of content or attachments must be non-empty
((!m.content || !m.content.trim()) &&
(!Array.isArray(m.attachments) || !m.attachments.length))
)
) {
throw new CopilotPromptInvalid('Empty message content');
}
if (
messages.some(
m =>
typeof m.role !== 'string' ||
!m.role ||
!ChatMessageRole.includes(m.role)
)
) {
throw new CopilotPromptInvalid('Invalid message role');
}
// JSON mode requires the word 'json' to appear in the prompt content
// ref: https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format
if (
options.jsonMode &&
!messages.some(m => m.content.toLowerCase().includes('json'))
) {
throw new CopilotPromptInvalid('Prompt not support json mode');
}
} else if (
Array.isArray(embeddings) &&
embeddings.some(e => typeof e !== 'string' || !e || !e.trim())
) {
throw new CopilotPromptInvalid('Invalid embedding');
}
}
private handleError(e: any) {
if (e instanceof UserFriendlyError) {
return e;
} else if (e instanceof AISDKError) {
this.logger.error('Error thrown from AI SDK:', e);
return new CopilotProviderSideError({
provider: this.type,
kind: e.name || 'unknown',
message: e.message,
});
} else {
return new CopilotProviderSideError({
provider: this.type,
kind: 'unexpected_response',
message: e?.message || 'Unexpected Google response',
});
}
}
// ====== text to text ======
async generateText(
messages: PromptMessage[],
model: string = 'gemini-2.0-flash-001',
options: CopilotChatOptions = {}
): Promise<string> {
await this.checkParams({ messages, model, options });
try {
metrics.ai.counter('chat_text_calls').add(1, { model });
const [system, msgs] = this.chatToGPTMessage(messages);
const { text } = await generateText({
model: this.instance(model, {
audioTimestamp: Boolean(options.audioTimestamp),
structuredOutputs: Boolean(options.jsonMode),
}),
system,
messages: msgs,
abortSignal: options.signal,
});
if (!text) throw new Error('Failed to generate text');
return text.trim();
} catch (e: any) {
metrics.ai.counter('chat_text_errors').add(1, { model });
throw this.handleError(e);
}
}
async *generateTextStream(
messages: PromptMessage[],
model: string = 'gemini-2.0-flash-001',
options: CopilotChatOptions = {}
): AsyncIterable<string> {
await this.checkParams({ messages, model, options });
try {
metrics.ai.counter('chat_text_stream_calls').add(1, { model });
const [system, msgs] = this.chatToGPTMessage(messages);
const { textStream } = streamText({
model: this.instance(model, {
audioTimestamp: Boolean(options.audioTimestamp),
structuredOutputs: Boolean(options.jsonMode),
}),
system,
messages: msgs,
abortSignal: options.signal,
});
for await (const message of textStream) {
if (message) {
yield message;
if (options.signal?.aborted) {
await textStream.cancel();
break;
}
}
}
} catch (e: any) {
metrics.ai.counter('chat_text_stream_errors').add(1, { model });
throw this.handleError(e);
}
}
}
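Taken together with the 'Transcript audio' prompt above, the provider can be exercised directly. A minimal sketch, assuming COPILOT_GOOGLE_API_KEY is set and using a placeholder audio URL; it mirrors the prompt's audioTimestamp/jsonMode config and is not part of the commit:

const provider = new GoogleProvider({
  apiKey: process.env.COPILOT_GOOGLE_API_KEY ?? '',
});

const transcript = await provider.generateText(
  [
    {
      role: 'system',
      content: 'Transcribe the audio into a JSON array of segments.',
    },
    {
      role: 'user',
      content: '',
      // placeholder URL; any reachable audio file with a known extension
      attachments: ['https://example.com/recording.mp3'],
    },
  ],
  'gemini-2.0-flash-001',
  { audioTimestamp: true, jsonMode: true }
);
console.log(JSON.parse(transcript));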

View File

@@ -77,6 +77,8 @@ export const PromptConfigStrictSchema = z.object({
)
.nullable()
.optional(),
// google
audioTimestamp: z.boolean().nullable().optional(),
});
export const PromptConfigSchema =
@@ -156,6 +158,7 @@ export type ListHistoriesOptions = {
export enum CopilotProviderType {
FAL = 'fal',
Google = 'google',
OpenAI = 'openai',
Perplexity = 'perplexity',
// only for test