feat(server): audio transcription (#10733)

This commit is contained in:
darkskygit
2025-03-20 07:12:27 +00:00
parent bd5d930490
commit c16ae2d5b4
19 changed files with 712 additions and 58 deletions

View File

@@ -14,7 +14,7 @@
"test": "ava --concurrency 1 --serial",
"test:copilot": "ava \"src/__tests__/**/copilot-*.spec.ts\"",
"test:coverage": "c8 ava --concurrency 1 --serial",
"test:copilot:coverage": "c8 ava --timeout=5m \"src/__tests__/**/copilot-*.spec.ts\"",
"test:copilot:coverage": "c8 ava --timeout=5m \"src/__tests__/copilot-*.spec.ts\"",
"e2e": "cross-env TEST_MODE=e2e ava",
"e2e:coverage": "cross-env TEST_MODE=e2e c8 ava",
"data-migration": "cross-env NODE_ENV=development r ./src/data/index.ts",

View File

@@ -1,4 +1,10 @@
import { AiJobStatus, AiJobType, PrismaClient } from '@prisma/client';
import {
AiJobStatus,
AiJobType,
PrismaClient,
User,
Workspace,
} from '@prisma/client';
import ava, { TestFn } from 'ava';
import { Config } from '../../base';
@@ -28,8 +34,15 @@ test.before(async t => {
t.context.module = module;
});
let user: User;
let workspace: Workspace;
test.beforeEach(async t => {
await t.context.module.initTestingDB();
user = await t.context.user.create({
email: 'test@affine.pro',
});
workspace = await t.context.workspace.create(user.id);
});
test.after(async t => {
@@ -37,11 +50,6 @@ test.after(async t => {
});
test('should create a copilot job', async t => {
const user = await t.context.user.create({
email: 'test@affine.pro',
});
const workspace = await t.context.workspace.create(user.id);
const data = {
workspaceId: workspace.id,
blobId: 'blob-id',
@@ -71,10 +79,6 @@ test('should get null for non-exist job', async t => {
});
test('should update job', async t => {
const user = await t.context.user.create({
email: 'test@affine.pro',
});
const workspace = await t.context.workspace.create(user.id);
const { id: jobId } = await t.context.copilotJob.create({
workspaceId: workspace.id,
blobId: 'blob-id',
@@ -97,10 +101,6 @@ test('should update job', async t => {
});
test('should claim job', async t => {
const user = await t.context.user.create({
email: 'test@affine.pro',
});
const workspace = await t.context.workspace.create(user.id);
const { id: jobId } = await t.context.copilotJob.create({
workspaceId: workspace.id,
blobId: 'blob-id',

View File

@@ -697,6 +697,10 @@ export const USER_FRIENDLY_ERRORS = {
type: 'action_forbidden',
message: `Embedding feature not available, you may need to install pgvector extension to your database`,
},
copilot_transcription_job_exists: {
type: 'bad_request',
message: () => 'Transcription job already exists',
},
// Quota & Limit errors
blob_quota_exceeded: {

View File

@@ -753,6 +753,12 @@ export class CopilotEmbeddingUnavailable extends UserFriendlyError {
}
}
/**
 * User-facing error for a transcription job that already exists.
 *
 * Forwards the `bad_request` type and the `copilot_transcription_job_exists`
 * name to `UserFriendlyError`; the optional `message` is passed through to the
 * base class unchanged.
 */
export class CopilotTranscriptionJobExists extends UserFriendlyError {
constructor(message?: string) {
super('bad_request', 'copilot_transcription_job_exists', message);
}
}
export class BlobQuotaExceeded extends UserFriendlyError {
constructor(message?: string) {
super('quota_exceeded', 'blob_quota_exceeded', message);
@@ -1000,6 +1006,7 @@ export enum ErrorNames {
COPILOT_FAILED_TO_MODIFY_CONTEXT,
COPILOT_FAILED_TO_MATCH_CONTEXT,
COPILOT_EMBEDDING_UNAVAILABLE,
COPILOT_TRANSCRIPTION_JOB_EXISTS,
BLOB_QUOTA_EXCEEDED,
STORAGE_QUOTA_EXCEEDED,
MEMBER_QUOTA_EXCEEDED,

View File

@@ -31,6 +31,10 @@ import {
} from './resolver';
import { ChatSessionService } from './session';
import { CopilotStorage } from './storage';
import {
CopilotTranscriptionResolver,
CopilotTranscriptionService,
} from './transcript';
import { CopilotWorkflowExecutors, CopilotWorkflowService } from './workflow';
registerCopilotProvider(FalProvider);
@@ -58,6 +62,9 @@ registerCopilotProvider(PerplexityProvider);
CopilotContextResolver,
CopilotContextService,
CopilotContextDocJob,
// transcription
CopilotTranscriptionService,
CopilotTranscriptionResolver,
],
controllers: [CopilotController],
contributesTo: ServerFeature.Copilot,

View File

@@ -101,53 +101,52 @@ export class GoogleProvider implements CopilotTextToTextProvider {
return undefined;
}
protected chatToGPTMessage(
protected async chatToGPTMessage(
messages: PromptMessage[]
): [string | undefined, ChatMessage[]] {
): Promise<[string | undefined, ChatMessage[]]> {
let system =
messages[0]?.role === 'system' ? messages.shift()?.content : undefined;
// filter redundant fields
const msgs = messages
.filter(m => m.role !== 'system')
.map(({ role, content, attachments, params }) => {
content = content.trim();
role = role as 'user' | 'assistant';
const mimetype = params?.mimetype;
if (Array.isArray(attachments)) {
const contents: (TextPart | FilePart)[] = [];
if (content.length) {
contents.push({
type: 'text',
text: content,
});
}
contents.push(
...attachments
.map(url => {
if (SIMPLE_IMAGE_URL_REGEX.test(url)) {
const mimeType =
typeof mimetype === 'string'
? mimetype
: this.inferMimeType(url);
if (mimeType) {
const data = url.startsWith('data:') ? url : new URL(url);
return {
type: 'file' as const,
data,
mimeType,
};
}
}
return undefined;
})
.filter(c => !!c)
);
return { role, content: contents } as ChatMessage;
} else {
return { role, content } as ChatMessage;
const msgs: ChatMessage[] = [];
for (let { role, content, attachments, params } of messages.filter(
m => m.role !== 'system'
)) {
content = content.trim();
role = role as 'user' | 'assistant';
const mimetype = params?.mimetype;
if (Array.isArray(attachments)) {
const contents: (TextPart | FilePart)[] = [];
if (content.length) {
contents.push({
type: 'text',
text: content,
});
}
});
for (const url of attachments) {
if (SIMPLE_IMAGE_URL_REGEX.test(url)) {
const mimeType =
typeof mimetype === 'string' ? mimetype : this.inferMimeType(url);
if (mimeType) {
const data = url.startsWith('data:')
? await fetch(url).then(r => r.arrayBuffer())
: new URL(url);
contents.push({
type: 'file' as const,
data,
mimeType,
});
}
}
}
msgs.push({ role, content: contents } as ChatMessage);
} else {
msgs.push({ role, content });
}
}
return [system, msgs];
}
@@ -237,7 +236,7 @@ export class GoogleProvider implements CopilotTextToTextProvider {
try {
metrics.ai.counter('chat_text_calls').add(1, { model });
const [system, msgs] = this.chatToGPTMessage(messages);
const [system, msgs] = await this.chatToGPTMessage(messages);
const { text } = await generateText({
model: this.instance(model, {
@@ -266,7 +265,7 @@ export class GoogleProvider implements CopilotTextToTextProvider {
try {
metrics.ai.counter('chat_text_stream_calls').add(1, { model });
const [system, msgs] = this.chatToGPTMessage(messages);
const [system, msgs] = await this.chatToGPTMessage(messages);
const { textStream } = streamText({
model: this.instance(model),

View File

@@ -0,0 +1,2 @@
export { CopilotTranscriptionResolver } from './resolver';
export { CopilotTranscriptionService } from './service';

View File

@@ -0,0 +1,134 @@
import { Injectable } from '@nestjs/common';
import {
Args,
Field,
ID,
Mutation,
ObjectType,
Parent,
registerEnumType,
ResolveField,
Resolver,
} from '@nestjs/graphql';
import { AiJobStatus } from '@prisma/client';
import GraphQLUpload from 'graphql-upload/GraphQLUpload.mjs';
import type { FileUpload } from '../../../base';
import { CurrentUser } from '../../../core/auth';
import { AccessController } from '../../../core/permission';
import { CopilotType } from '../resolver';
import { CopilotTranscriptionService, TranscriptionJob } from './service';
import type { TranscriptionItem, TranscriptionPayload } from './types';
// Expose the Prisma `AiJobStatus` enum to the GraphQL schema under the name
// `AiJobStatus`, so it can be used as a field type below.
registerEnumType(AiJobStatus, {
name: 'AiJobStatus',
});
/**
 * GraphQL object type for a single transcription segment.
 *
 * Mirrors `TranscriptionItem`; every field is exposed as a GraphQL `String`.
 * NOTE(review): `start`/`end` look like segment timestamps, but their format
 * is not defined in this file — confirm against the transcription prompt.
 */
@ObjectType()
class TranscriptionItemType implements TranscriptionItem {
@Field(() => String)
speaker!: string;
@Field(() => String)
start!: string;
@Field(() => String)
end!: string;
@Field(() => String)
transcription!: string;
}
/**
 * GraphQL object type describing a transcription job as seen by clients.
 *
 * `transcription` and `summary` are nullable in the schema; `id` and `status`
 * are always present.
 */
@ObjectType()
class TranscriptionResultType implements TranscriptionPayload {
// Identifier of the job, exposed as a GraphQL `ID`.
@Field(() => ID)
id!: string;
// Transcribed segments, or null when none are available.
@Field(() => [TranscriptionItemType], { nullable: true })
transcription!: TranscriptionItemType[] | null;
// Generated summary text, or null when none is available.
@Field(() => String, { nullable: true })
summary!: string | null;
// Current job status, using the enum registered as `AiJobStatus`.
@Field(() => AiJobStatus)
status!: AiJobStatus;
}
/**
 * GraphQL resolver exposing audio transcription on the `Copilot` type:
 * submitting a job, claiming a job, and querying a job in a workspace.
 */
@Injectable()
@Resolver(() => CopilotType)
export class CopilotTranscriptionResolver {
  constructor(
    private readonly ac: AccessController,
    private readonly service: CopilotTranscriptionService
  ) {}

  /**
   * Asserts that `userId` holds the `Workspace.Copilot` permission on
   * `workspaceId` (local workspaces are allowed). Rejects otherwise.
   */
  private async assertCopilotAccess(
    userId: string,
    workspaceId: string
  ): Promise<void> {
    await this.ac
      .user(userId)
      .workspace(workspaceId)
      .allowLocal()
      .assert('Workspace.Copilot');
  }

  /**
   * Converts a service-level job into the GraphQL result shape.
   *
   * @returns `null` when there is no job; otherwise the job id and status with
   * `transcription`/`summary` normalised to `null` when missing.
   */
  private handleJobResult(
    job: TranscriptionJob | null
  ): TranscriptionResultType | null {
    if (!job) {
      return null;
    }

    const { id, status, transcription: payload } = job;
    return {
      id,
      // `||` is intentional: an empty summary is reported as null as well.
      transcription: payload?.transcription || null,
      summary: payload?.summary || null,
      status,
    };
  }

  /**
   * Uploads an audio blob and submits a transcription job for it.
   *
   * Requires the `Workspace.Copilot` permission on `workspaceId`.
   */
  @Mutation(() => TranscriptionResultType, { nullable: true })
  async submitAudioTranscription(
    @CurrentUser() user: CurrentUser,
    @Args('workspaceId') workspaceId: string,
    @Args('blobId') blobId: string,
    @Args({ name: 'blob', type: () => GraphQLUpload })
    blob: FileUpload
  ): Promise<TranscriptionResultType | null> {
    await this.assertCopilotAccess(user.id, workspaceId);

    const job = await this.service.submitTranscriptionJob(
      user.id,
      workspaceId,
      blobId,
      blob
    );
    return this.handleJobResult(job);
  }

  /**
   * Claims the job `jobId` for the current user.
   *
   * @returns the claimed job including its payload, or `null` when the service
   * did not report the job as claimed.
   */
  @Mutation(() => TranscriptionResultType, { nullable: true })
  async claimAudioTranscription(
    @CurrentUser() user: CurrentUser,
    @Args('jobId') jobId: string
  ): Promise<TranscriptionResultType | null> {
    const job = await this.service.claimTranscriptionJob(user.id, jobId);
    return this.handleJobResult(job);
  }

  /**
   * Looks up a transcription job of the current user in the parent
   * `Copilot`'s workspace.
   *
   * The field is declared as a non-null list of `TranscriptionResultType`, so
   * the resolver must always return an array: a single-element list when the
   * job is found, an empty list otherwise. Returning a bare object or `null`
   * (as before) makes GraphQL reject the value for a list field.
   *
   * NOTE(review): `jobId` is nullable in the schema but typed `string` here and
   * forwarded as-is to the service — confirm how an omitted id should behave.
   */
  @ResolveField(() => [TranscriptionResultType], {})
  async audioTranscription(
    @Parent() copilot: CopilotType,
    @CurrentUser() user: CurrentUser,
    @Args('jobId', { nullable: true })
    jobId: string
  ): Promise<TranscriptionResultType[]> {
    if (!copilot.workspaceId) {
      return [];
    }
    await this.assertCopilotAccess(user.id, copilot.workspaceId);

    const job = await this.service.queryTranscriptionJob(
      user.id,
      copilot.workspaceId,
      jobId
    );
    const result = this.handleJobResult(job);
    return result ? [result] : [];
  }
}

View File

@@ -0,0 +1,206 @@
import { Injectable } from '@nestjs/common';
import { AiJobStatus, AiJobType } from '@prisma/client';
import {
CopilotPromptNotFound,
CopilotTranscriptionJobExists,
type FileUpload,
JobQueue,
NoCopilotProviderAvailable,
OnJob,
} from '../../../base';
import { Models } from '../../../models';
import { PromptService } from '../prompt';
import { CopilotProviderService } from '../providers';
import { CopilotStorage } from '../storage';
import {
CopilotCapability,
CopilotTextProvider,
PromptMessage,
} from '../types';
import {
TranscriptionPayload,
TranscriptionSchema,
TranscriptPayloadSchema,
} from './types';
import { readStream } from './utils';
/**
 * Service-level view of a transcription job: its id, its current status and,
 * when available, the parsed payload (transcription and/or summary).
 */
export type TranscriptionJob = {
id: string;
status: AiJobStatus;
transcription?: TranscriptionPayload;
};
/**
 * Drives audio transcription jobs: stores the uploaded audio, queues the
 * transcription and summary steps, and exposes claim/query operations.
 */
@Injectable()
export class CopilotTranscriptionService {
  constructor(
    private readonly models: Models,
    private readonly job: JobQueue,
    private readonly storage: CopilotStorage,
    private readonly prompt: PromptService,
    private readonly provider: CopilotProviderService
  ) {}

  /**
   * Runs `task` and, if it throws, marks the job as `failed` before rethrowing
   * the original error. Without this a job whose processing crashed would keep
   * its in-progress status indefinitely.
   */
  private async failJobOnError<T>(
    jobId: string,
    task: () => Promise<T>
  ): Promise<T> {
    try {
      return await task();
    } catch (error: unknown) {
      try {
        await this.models.copilotJob.update(jobId, {
          status: AiJobStatus.failed,
        });
      } catch {
        // Best effort only: the original failure is the error worth reporting.
      }
      throw error;
    }
  }

  /**
   * Creates a transcription job for `blobId`, stores the uploaded audio and
   * queues the `copilot.transcript.submit` job.
   *
   * @returns the job id together with the status the job was moved to
   * (`running`).
   * @throws CopilotTranscriptionJobExists when a job already exists for this
   * workspace/blob pair.
   */
  async submitTranscriptionJob(
    userId: string,
    workspaceId: string,
    blobId: string,
    blob: FileUpload
  ): Promise<TranscriptionJob> {
    if (await this.models.copilotJob.has(workspaceId, blobId)) {
      throw new CopilotTranscriptionJobExists();
    }

    const { id: jobId } = await this.models.copilotJob.create({
      workspaceId,
      blobId,
      createdBy: userId,
      type: AiJobType.transcription,
    });

    // Everything after creation can fail (oversized upload, storage, queue);
    // record that on the job instead of leaving it pending forever.
    return await this.failJobOnError<TranscriptionJob>(jobId, async () => {
      const buffer = await readStream(blob.createReadStream());
      const url = await this.storage.put(userId, workspaceId, blobId, buffer);

      const status = AiJobStatus.running;
      await this.models.copilotJob.update(jobId, { status });

      await this.job.add(
        'copilot.transcript.submit',
        {
          jobId,
          url,
          mimeType: blob.mimetype,
        },
        // NOTE(review): this was annotated "retry 3 times", but the option is
        // `removeOnFail`. If JobQueue forwards BullMQ job options, this limits
        // how many failed jobs are retained; retries would need `attempts`.
        { removeOnFail: 3 }
      );

      // Report the status that was just persisted, not the one from creation.
      return { id: jobId, status };
    });
  }

  /**
   * Claims `jobId` for `userId`.
   *
   * @returns the job with its parsed payload when the model reports it as
   * `claimed`, otherwise `null`.
   */
  async claimTranscriptionJob(
    userId: string,
    jobId: string
  ): Promise<TranscriptionJob | null> {
    const status = await this.models.copilotJob.claim(jobId, userId);
    if (status !== AiJobStatus.claimed) {
      return null;
    }

    const transcription = await this.models.copilotJob.getPayload(
      jobId,
      TranscriptPayloadSchema
    );
    return { id: jobId, transcription, status };
  }

  /**
   * Fetches a transcription job of `userId` in `workspaceId`.
   *
   * @returns `null` when no such job exists; otherwise id and status, plus the
   * payload when it validates against `TranscriptPayloadSchema`.
   */
  async queryTranscriptionJob(
    userId: string,
    workspaceId: string,
    jobId: string
  ) {
    const job = await this.models.copilotJob.getWithUser(
      userId,
      workspaceId,
      jobId,
      AiJobType.transcription
    );
    if (!job) {
      return null;
    }

    const result: TranscriptionJob = { id: job.id, status: job.status };
    const parsed = TranscriptPayloadSchema.safeParse(job.payload);
    if (parsed.success) {
      result.transcription = parsed.data;
    }
    return result;
  }

  /**
   * Resolves a text-to-text provider for `model`.
   *
   * @throws NoCopilotProviderAvailable when no provider supports the model.
   */
  private async getProvider(model: string): Promise<CopilotTextProvider> {
    const provider = await this.provider.getProviderByCapability(
      CopilotCapability.TextToText,
      model
    );
    if (!provider) {
      throw new NoCopilotProviderAvailable();
    }
    return provider;
  }

  /**
   * Sends the named prompt followed by one user message (built from `message`
   * on top of an empty-content default) to the prompt's model.
   *
   * @throws CopilotPromptNotFound when the prompt does not exist.
   */
  private async chatWithPrompt(
    promptName: string,
    message: Partial<PromptMessage>
  ): Promise<string> {
    const prompt = await this.prompt.get(promptName);
    if (!prompt) {
      throw new CopilotPromptNotFound({ name: promptName });
    }

    const provider = await this.getProvider(prompt.model);
    return provider.generateText(
      [...prompt.finish({}), { role: 'user', content: '', ...message }],
      prompt.model
    );
  }

  /**
   * Strips Markdown code fences from a model response and trims it.
   *
   * The opening-fence pattern is limited to the fence line itself (optional
   * language tag, optional horizontal whitespace, then the newline). The
   * previous `[\w\s]+` also matched newlines, so it could delete following
   * lines made only of words and whitespace, and it missed a bare "```" fence.
   */
  private cleanupResponse(response: string): string {
    return response
      .replace(/```[^\S\n]*[\w-]*[^\S\n]*\n/g, '')
      .replace(/\n```/g, '')
      .trim();
  }

  /**
   * Job handler: transcribes the audio at `url`, validates the model output
   * against `TranscriptionSchema`, stores it on the job and queues the summary
   * step. Any failure marks the job as `failed` and is rethrown.
   */
  @OnJob('copilot.transcript.submit')
  async transcriptAudio({
    jobId,
    url,
    mimeType,
  }: Jobs['copilot.transcript.submit']) {
    await this.failJobOnError(jobId, async () => {
      const result = await this.chatWithPrompt('Transcript audio', {
        attachments: [url],
        params: { mimetype: mimeType },
      });

      const transcription = TranscriptionSchema.parse(
        JSON.parse(this.cleanupResponse(result))
      );
      await this.models.copilotJob.update(jobId, {
        payload: { transcription },
      });

      await this.job.add(
        'copilot.summary.submit',
        {
          jobId,
        },
        // NOTE(review): see submitTranscriptionJob — `removeOnFail` may not
        // configure retries; confirm the intended option.
        { removeOnFail: 3 }
      );
    });
  }

  /**
   * Job handler: summarises the stored transcription and completes the job.
   *
   * On success the summary is saved and the job is moved to `finished`; no
   * other code path in this service marks a job as successfully completed.
   * When the payload has no transcription, or anything throws, the job is
   * marked `failed`.
   */
  @OnJob('copilot.summary.submit')
  async summaryTranscription({ jobId }: Jobs['copilot.summary.submit']) {
    await this.failJobOnError(jobId, async () => {
      const payload = await this.models.copilotJob.getPayload(
        jobId,
        TranscriptPayloadSchema
      );

      if (!payload.transcription) {
        await this.models.copilotJob.update(jobId, {
          status: AiJobStatus.failed,
        });
        return;
      }

      const content = payload.transcription
        .map(item => item.transcription)
        .join('\n');
      const result = await this.chatWithPrompt('Summary', { content });
      payload.summary = this.cleanupResponse(result);

      await this.models.copilotJob.update(jobId, {
        payload,
        status: AiJobStatus.finished,
      });
    });
  }
}

View File

@@ -0,0 +1,36 @@
import { z } from 'zod';
import { OneMB } from '../../../base';
// One transcription segment; every field is validated as a string.
const TranscriptionItemSchema = z.object({
speaker: z.string(),
start: z.string(),
end: z.string(),
transcription: z.string(),
});
// A complete transcription: an array of segments.
export const TranscriptionSchema = z.array(TranscriptionItemSchema);
// Payload stored on a transcription job. Both fields may be missing or null,
// so a payload with only one of them (or neither) still validates.
export const TranscriptPayloadSchema = z.object({
transcription: TranscriptionSchema.nullable().optional(),
summary: z.string().nullable().optional(),
});
// Static types inferred from the schemas above, so the runtime validation and
// the compile-time types cannot drift apart.
export type TranscriptionItem = z.infer<typeof TranscriptionItemSchema>;
export type Transcription = z.infer<typeof TranscriptionSchema>;
export type TranscriptionPayload = z.infer<typeof TranscriptPayloadSchema>;
// Augments the global `Jobs` interface with the payload shapes of the two
// transcription-related queue jobs.
declare global {
interface Jobs {
// Payload for the transcription step: the job id, the URL of the audio and
// its MIME type.
'copilot.transcript.submit': {
jobId: string;
url: string;
mimeType: string;
};
// Payload for the summary step: only the job id is carried.
'copilot.summary.submit': {
jobId: string;
};
}
}
// Size limit for transcription input: 50 × `OneMB`.
// NOTE(review): presumably a byte count — confirm against the `OneMB` definition.
export const MAX_TRANSCRIPTION_SIZE = 50 * OneMB;

View File

@@ -0,0 +1,11 @@
import { Readable } from 'node:stream';
import { readBufferWithLimit } from '../../../base';
import { MAX_TRANSCRIPTION_SIZE } from './types';
/**
 * Reads `readable` into a single `Buffer` by delegating to
 * `readBufferWithLimit`.
 *
 * @param readable stream to consume
 * @param maxSize limit forwarded to `readBufferWithLimit`; defaults to
 * `MAX_TRANSCRIPTION_SIZE`
 * @returns a promise for the buffered content
 *
 * NOTE(review): what happens when the limit is exceeded is decided by
 * `readBufferWithLimit`, which is not visible here — confirm.
 */
export function readStream(
readable: Readable,
maxSize = MAX_TRANSCRIPTION_SIZE
): Promise<Buffer> {
return readBufferWithLimit(readable, maxSize);
}

View File

@@ -18,6 +18,14 @@ input AddRemoveContextCategoryInput {
type: ContextCategories!
}
# Status values an AI job can report; registered from the server-side
# `AiJobStatus` enum.
enum AiJobStatus {
claimed
failed
finished
pending
running
}
type AlreadyInSpaceDataType {
spaceId: String!
}
@@ -72,6 +80,8 @@ type ContextWorkspaceEmbeddingStatus {
}
type Copilot {
audioTranscription(jobId: String): [TranscriptionResultType!]!
"""Get the context list of a session"""
contexts(contextId: String, sessionId: String): [CopilotContext!]!
histories(docId: String, options: QueryChatHistoriesInput): [CopilotHistories!]!
@@ -409,6 +419,7 @@ enum ErrorNames {
COPILOT_QUOTA_EXCEEDED
COPILOT_SESSION_DELETED
COPILOT_SESSION_NOT_FOUND
COPILOT_TRANSCRIPTION_JOB_EXISTS
CUSTOMER_PORTAL_CREATE_FAILED
DOC_ACTION_DENIED
DOC_DEFAULT_ROLE_CAN_NOT_BE_OWNER
@@ -842,6 +853,7 @@ type Mutation {
cancelSubscription(idempotencyKey: String @deprecated(reason: "use header `Idempotency-Key`"), plan: SubscriptionPlan = Pro, workspaceId: String): SubscriptionType!
changeEmail(email: String!, token: String!): UserType!
changePassword(newPassword: String!, token: String!, userId: String): Boolean!
claimAudioTranscription(jobId: String!): TranscriptionResultType
"""Cleanup sessions"""
cleanupCopilotSession(options: DeleteSessionInput!): [String!]!
@@ -934,6 +946,7 @@ type Mutation {
sendVerifyChangeEmail(callbackUrl: String!, email: String!, token: String!): Boolean!
sendVerifyEmail(callbackUrl: String!): Boolean!
setBlob(blob: Upload!, workspaceId: String!): String!
submitAudioTranscription(blob: Upload!, blobId: String!, workspaceId: String!): TranscriptionResultType
"""Update a copilot prompt"""
updateCopilotPrompt(messages: [CopilotPromptMessageInput!]!, name: String!): CopilotPromptType!
@@ -1386,6 +1399,20 @@ enum SubscriptionVariant {
Onetime
}
# One transcription segment; all fields are non-null strings.
type TranscriptionItemType {
end: String!
speaker: String!
start: String!
transcription: String!
}
# A transcription job as returned to clients: `id` and `status` are always
# present, `summary` and `transcription` are nullable.
type TranscriptionResultType {
id: ID!
status: AiJobStatus!
summary: String
transcription: [TranscriptionItemType!]
}
union UnionNotificationBodyType = InvitationAcceptedNotificationBodyType | InvitationBlockedNotificationBodyType | InvitationNotificationBodyType | MentionNotificationBodyType
type UnknownOauthProviderDataType {