fix(native): expose tapped audio stats (#10524)

Encode the audio using the stream's actual sample rate and channel count instead of a hard-coded 44.1 kHz mono format.
Also fixes the global audio tap not receiving any samples at all.
pengx17
2025-02-28 13:24:02 +00:00
parent 61541a2d15
commit bab4a07c9f
10 changed files with 152 additions and 184 deletions

View File

@@ -98,8 +98,8 @@ export async function gemini(
   try {
     // Upload the audio file
     uploadResult = await fileManager.uploadFile(audioFilePath, {
-      mimeType: 'audio/wav',
-      displayName: 'audio_transcription.wav',
+      mimeType: 'audio/mp3',
+      displayName: 'audio_transcription.mp3',
     });
     console.log('File uploaded:', uploadResult.file.uri);
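Reviewer note: the upload uses 'audio/mp3', which the Gemini File API accepts for MP3 audio, while the HTTP route later in this commit serves .mp3 files with the IANA-registered type 'audio/mpeg'; the mismatch is intentional, not a typo.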

View File

@@ -4,6 +4,8 @@ import { createServer } from 'node:http';
 import {
   type Application,
   type AudioTapStream,
+  Bitrate,
+  Mp3Encoder,
   ShareableContent,
   type TappableApplication,
 } from '@affine/native';
@@ -15,7 +17,6 @@ import fs from 'fs-extra';
 import { Server } from 'socket.io';
 import { gemini, type TranscriptionResult } from './gemini';
-import { WavWriter } from './wav-writer';
 // Constants
 const RECORDING_DIR = './recordings';
@@ -51,6 +52,7 @@ interface RecordingMetadata {
   recordingEndTime: number;
   recordingDuration: number;
   sampleRate: number;
+  channels: number;
   totalSamples: number;
 }
@@ -100,7 +102,9 @@ app.use(
       return res.status(400).json({ error: 'Invalid folder name format' });
     }
-    if (req.path.endsWith('.wav')) {
+    if (req.path.endsWith('.mp3')) {
+      res.setHeader('Content-Type', 'audio/mpeg');
+    } else if (req.path.endsWith('.wav')) {
       res.setHeader('Content-Type', 'audio/wav');
     } else if (req.path.endsWith('.png')) {
       res.setHeader('Content-Type', 'image/png');
@@ -123,19 +127,25 @@ async function saveRecording(recording: Recording): Promise<string | null> {
   const recordingEndTime = Date.now();
   const recordingDuration = (recordingEndTime - recording.startTime) / 1000;
-  const expectedSamples = recordingDuration * 44100;
+  // Get the actual sample rate from the stream's audio stats
+  const actualSampleRate = recording.stream.sampleRate;
+  const channelCount = recording.stream.channels;
+  const expectedSamples = recordingDuration * actualSampleRate;
   console.log(`💾 Saving recording for ${app.name}:`);
   console.log(`- Process ID: ${app.processId}`);
   console.log(`- Bundle ID: ${app.bundleIdentifier}`);
   console.log(`- Actual duration: ${recordingDuration.toFixed(2)}s`);
+  console.log(`- Sample rate: ${actualSampleRate}Hz`);
+  console.log(`- Channels: ${channelCount}`);
   console.log(`- Expected samples: ${Math.floor(expectedSamples)}`);
   console.log(`- Actual samples: ${totalSamples}`);
   console.log(
     `- Sample ratio: ${(totalSamples / expectedSamples).toFixed(2)}`
   );
-  // Create a buffer for the mono audio
+  // Create a buffer for the audio
   const buffer = new Float32Array(totalSamples);
   let offset = 0;
   recording.buffers.forEach(buf => {
@@ -150,28 +160,33 @@ async function saveRecording(recording: Recording): Promise<string | null> {
   const recordingDir = `${RECORDING_DIR}/${baseFilename}`;
   await fs.ensureDir(recordingDir);
-  const wavFilename = `${recordingDir}/recording.wav`;
-  const transcriptionWavFilename = `${recordingDir}/transcription.wav`;
+  const mp3Filename = `${recordingDir}/recording.mp3`;
+  const transcriptionMp3Filename = `${recordingDir}/transcription.mp3`;
   const metadataFilename = `${recordingDir}/metadata.json`;
   const iconFilename = `${recordingDir}/icon.png`;
-  // Save high-quality WAV file for playback (44.1kHz)
-  console.log(`📝 Writing high-quality WAV file to ${wavFilename}`);
-  const writer = new WavWriter(wavFilename, { targetSampleRate: 44100 });
-  writer.write(buffer);
-  await writer.end();
-  console.log('✅ High-quality WAV file written successfully');
-  // Save low-quality WAV file for transcription (8kHz)
-  console.log(
-    `📝 Writing transcription WAV file to ${transcriptionWavFilename}`
-  );
-  const transcriptionWriter = new WavWriter(transcriptionWavFilename, {
-    targetSampleRate: 8000,
+  // Save MP3 file with the actual sample rate from the stream
+  console.log(`📝 Writing MP3 file to ${mp3Filename}`);
+  const mp3Encoder = new Mp3Encoder({
+    channels: channelCount,
+    sampleRate: actualSampleRate,
   });
-  transcriptionWriter.write(buffer);
-  await transcriptionWriter.end();
-  console.log('✅ Transcription WAV file written successfully');
+  const mp3Data = mp3Encoder.encode(buffer);
+  await fs.writeFile(mp3Filename, mp3Data);
+  console.log('✅ MP3 file written successfully');
+  // Save low-quality MP3 file for transcription (8kHz)
+  console.log(
+    `📝 Writing transcription MP3 file to ${transcriptionMp3Filename}`
+  );
+  const transcriptionMp3Encoder = new Mp3Encoder({
+    channels: channelCount,
+    bitrate: Bitrate.Kbps8,
+    sampleRate: actualSampleRate,
+  });
+  const transcriptionMp3Data = transcriptionMp3Encoder.encode(buffer);
+  await fs.writeFile(transcriptionMp3Filename, transcriptionMp3Data);
+  console.log('✅ Transcription MP3 file written successfully');
   // Save app icon if available
   if (app.icon) {
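Reviewer note: a minimal sketch of the Mp3Encoder API as exercised above, assuming encode() takes the Float32Array samples and returns bytes writable to disk; the helper name and its signature are illustrative, not part of this commit.

import { Bitrate, Mp3Encoder } from '@affine/native';
import fs from 'fs-extra';

// Illustrative helper, not part of this commit: encode captured float
// samples at the tap's reported rate instead of assuming 44.1kHz mono.
async function encodeToMp3(
  samples: Float32Array,
  sampleRate: number,
  channels: number,
  outPath: string,
  lowQuality = false
): Promise<void> {
  const encoder = new Mp3Encoder({
    channels,
    sampleRate,
    // The 8 kbps bitrate is only used for the small transcription copy.
    ...(lowQuality ? { bitrate: Bitrate.Kbps8 } : {}),
  });
  const data = encoder.encode(samples);
  await fs.writeFile(outPath, data);
}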
@@ -181,7 +196,7 @@ async function saveRecording(recording: Recording): Promise<string | null> {
   }
   console.log(`📝 Writing metadata to ${metadataFilename}`);
-  // Save metadata (without icon)
+  // Save metadata with the actual sample rate from the stream
   const metadata: RecordingMetadata = {
     appName: app.name,
     bundleIdentifier: app.bundleIdentifier,
@@ -189,7 +204,8 @@ async function saveRecording(recording: Recording): Promise<string | null> {
     recordingStartTime: recording.startTime,
     recordingEndTime,
     recordingDuration,
-    sampleRate: 44100,
+    sampleRate: actualSampleRate,
+    channels: channelCount,
     totalSamples,
   };
@@ -296,7 +312,7 @@ async function stopRecording(processId: number) {
 // File management
 async function getRecordings(): Promise<
   {
-    wav: string;
+    mp3: string;
     metadata?: RecordingMetadata;
     transcription?: TranscriptionMetadata;
   }[]
@@ -340,7 +356,7 @@ async function getRecordings(): Promise<
       if (transcriptionExists) {
         transcription = await fs.readJson(transcriptionPath);
       } else {
-        // If transcription.wav exists but no transcription.json, it means transcription is available but not started
+        // If transcription.mp3 exists but no transcription.json, it means transcription is available but not started
         transcription = {
           transcriptionStartTime: 0,
           transcriptionEndTime: 0,
@@ -352,7 +368,7 @@ async function getRecordings(): Promise<
       }
       return {
-        wav: dir,
+        mp3: dir,
        metadata,
        transcription,
      };
@@ -402,21 +418,21 @@ async function setupRecordingsWatcher() {
   // Handle file events
   fsWatcher
     .on('add', async path => {
-      if (path.endsWith('.wav') || path.endsWith('.json')) {
+      if (path.endsWith('.mp3') || path.endsWith('.json')) {
         console.log(`📝 File added: ${path}`);
         const files = await getRecordings();
         io.emit('apps:saved', { recordings: files });
       }
     })
     .on('change', async path => {
-      if (path.endsWith('.wav') || path.endsWith('.json')) {
+      if (path.endsWith('.mp3') || path.endsWith('.json')) {
         console.log(`📝 File changed: ${path}`);
         const files = await getRecordings();
         io.emit('apps:saved', { recordings: files });
       }
     })
     .on('unlink', async path => {
-      if (path.endsWith('.wav') || path.endsWith('.json')) {
+      if (path.endsWith('.mp3') || path.endsWith('.json')) {
         console.log(`🗑️ File removed: ${path}`);
         const files = await getRecordings();
         io.emit('apps:saved', { recordings: files });
@@ -702,11 +718,11 @@ app.post(
       // Check if directory exists
       await fs.access(recordingDir);
-      const transcriptionWavPath = `${recordingDir}/transcription.wav`;
+      const transcriptionMp3Path = `${recordingDir}/transcription.mp3`;
       const transcriptionMetadataPath = `${recordingDir}/transcription.json`;
       // Check if transcription file exists
-      await fs.access(transcriptionWavPath);
+      await fs.access(transcriptionMp3Path);
       // Create initial transcription metadata
       const initialMetadata: TranscriptionMetadata = {
@@ -719,7 +735,7 @@ app.post(
       // Notify clients that transcription has started
       io.emit('apps:recording-transcription-start', { filename: foldername });
-      const transcription = await gemini(transcriptionWavPath, {
+      const transcription = await gemini(transcriptionMp3Path, {
         mode: 'transcript',
       });

View File

@@ -1,125 +0,0 @@
-import fs from 'fs-extra';
-
-interface WavWriterConfig {
-  targetSampleRate?: number;
-}
-
-export class WavWriter {
-  private readonly file: fs.WriteStream;
-  private readonly originalSampleRate: number = 44100;
-  private readonly targetSampleRate: number;
-  private readonly numChannels = 1; // The audio is mono
-  private samplesWritten = 0;
-  private readonly tempFilePath: string;
-  private readonly finalFilePath: string;
-
-  constructor(finalPath: string, config: WavWriterConfig = {}) {
-    this.finalFilePath = finalPath;
-    this.tempFilePath = finalPath + '.tmp';
-    this.targetSampleRate = config.targetSampleRate ?? this.originalSampleRate;
-    this.file = fs.createWriteStream(this.tempFilePath);
-    this.writeHeader(); // Always write header immediately
-  }
-
-  private writeHeader() {
-    const buffer = Buffer.alloc(44); // WAV header is 44 bytes
-    // RIFF chunk descriptor
-    buffer.write('RIFF', 0);
-    buffer.writeUInt32LE(36, 4); // Initial file size - 8 (will be updated later)
-    buffer.write('WAVE', 8);
-    // fmt sub-chunk
-    buffer.write('fmt ', 12);
-    buffer.writeUInt32LE(16, 16); // Subchunk1Size (16 for PCM)
-    buffer.writeUInt16LE(3, 20); // AudioFormat (3 for IEEE float)
-    buffer.writeUInt16LE(this.numChannels, 22); // NumChannels
-    buffer.writeUInt32LE(this.targetSampleRate, 24); // SampleRate
-    buffer.writeUInt32LE(this.targetSampleRate * this.numChannels * 4, 28); // ByteRate
-    buffer.writeUInt16LE(this.numChannels * 4, 32); // BlockAlign
-    buffer.writeUInt16LE(32, 34); // BitsPerSample (32 for float)
-    // data sub-chunk
-    buffer.write('data', 36);
-    buffer.writeUInt32LE(0, 40); // Initial data size (will be updated later)
-    this.file.write(buffer);
-  }
-
-  private resample(samples: Float32Array): Float32Array {
-    const ratio = this.originalSampleRate / this.targetSampleRate;
-    const newLength = Math.floor(samples.length / ratio);
-    const result = new Float32Array(newLength);
-    for (let i = 0; i < newLength; i++) {
-      const position = i * ratio;
-      const index = Math.floor(position);
-      const fraction = position - index;
-      // Linear interpolation between adjacent samples
-      if (index + 1 < samples.length) {
-        result[i] =
-          samples[index] * (1 - fraction) + samples[index + 1] * fraction;
-      } else {
-        result[i] = samples[index];
-      }
-    }
-    return result;
-  }
-
-  write(samples: Float32Array) {
-    // Resample the input samples
-    const resampledData = this.resample(samples);
-    // Create a buffer with the correct size (4 bytes per float)
-    const buffer = Buffer.alloc(resampledData.length * 4);
-    // Write each float value properly
-    for (let i = 0; i < resampledData.length; i++) {
-      buffer.writeFloatLE(resampledData[i], i * 4);
-    }
-    this.file.write(buffer);
-    this.samplesWritten += resampledData.length;
-  }
-
-  async end(): Promise<void> {
-    return new Promise<void>((resolve, reject) => {
-      this.file.end(() => {
-        void this.updateHeaderAndCleanup().then(resolve).catch(reject);
-      });
-    });
-  }
-
-  private async updateHeaderAndCleanup(): Promise<void> {
-    // Read the entire temporary file
-    const data = await fs.promises.readFile(this.tempFilePath);
-    // Update the header with correct sizes
-    const dataSize = this.samplesWritten * 4;
-    const fileSize = dataSize + 36;
-    data.writeUInt32LE(fileSize, 4); // Update RIFF chunk size
-    data.writeUInt32LE(dataSize, 40); // Update data chunk size
-    // Write the updated file
-    await fs.promises.writeFile(this.finalFilePath, data);
-    // Clean up temp file
-    await fs.promises.unlink(this.tempFilePath);
-  }
-}
-
-/**
- * Creates a Buffer from Float32Array audio data
- * @param float32Array - The Float32Array containing audio samples
- * @returns FileData - The audio data as a Buffer
- */
-export function FileData(float32Array: Float32Array): Buffer {
-  const buffer = Buffer.alloc(float32Array.length * 4); // 4 bytes per float
-  for (let i = 0; i < float32Array.length; i++) {
-    buffer.writeFloatLE(float32Array[i], i * 4);
-  }
-  return buffer;
-}
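Reviewer note: the deleted writer hard-coded originalSampleRate = 44100 and mono output, so its resample() would produce wrong-speed audio whenever the tap actually ran at another rate; the MP3 encoder introduced above is instead fed the rate and channel count reported by the stream, which is the point of this commit.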

View File

@@ -584,7 +584,8 @@ export function SavedRecordingItem({
   >(null);
   const metadata = recording.metadata;
-  const fileName = recording.wav;
+  // Ensure we have a valid filename, fallback to an empty string if undefined
+  const fileName = recording.mp3 || '';
   const recordingDate = metadata
     ? new Date(metadata.recordingStartTime).toLocaleString()
     : 'Unknown date';
@@ -626,7 +627,12 @@ export function SavedRecordingItem({
   const processAudioData = React.useCallback(async () => {
     try {
-      const response = await fetch(`/api/recordings/${fileName}/recording.wav`);
+      // Check if fileName is empty
+      if (!fileName) {
+        throw new Error('Invalid recording filename');
+      }
+
+      const response = await fetch(`/api/recordings/${fileName}/recording.mp3`);
       if (!response.ok) {
         throw new Error(
           `Failed to fetch audio file (${response.status}): ${response.statusText}`
@@ -741,7 +747,12 @@ export function SavedRecordingItem({
     setError(null); // Clear any previous errors
     try {
-      const response = await fetch(`/api/recordings/${recording.wav}`, {
+      // Check if filename is valid
+      if (!recording.mp3) {
+        throw new Error('Invalid recording filename');
+      }
+
+      const response = await fetch(`/api/recordings/${recording.mp3}`, {
         method: 'DELETE',
       });
@@ -765,7 +776,7 @@ export function SavedRecordingItem({
     } finally {
       setIsDeleting(false);
     }
-  }, [recording.wav]);
+  }, [recording.mp3]);
   const handleDeleteClick = React.useCallback(() => {
     void handleDelete().catch(err => {
@@ -779,7 +790,7 @@ export function SavedRecordingItem({
     socket.on(
       'apps:recording-transcription-start',
       (data: { filename: string }) => {
-        if (data.filename === recording.wav) {
+        if (recording.mp3 && data.filename === recording.mp3) {
          setTranscriptionError(null);
        }
      }
@@ -793,7 +804,7 @@ export function SavedRecordingItem({
        transcription?: string;
        error?: string;
      }) => {
-        if (data.filename === recording.wav && !data.success) {
+        if (recording.mp3 && data.filename === recording.mp3 && !data.success) {
          setTranscriptionError(data.error || 'Transcription failed');
        }
      }
@@ -803,12 +814,17 @@ export function SavedRecordingItem({
      socket.off('apps:recording-transcription-start');
      socket.off('apps:recording-transcription-end');
    };
-  }, [recording.wav]);
+  }, [recording.mp3]);
   const handleTranscribe = React.useCallback(async () => {
     try {
+      // Check if filename is valid
+      if (!recording.mp3) {
+        throw new Error('Invalid recording filename');
+      }
+
       const response = await fetch(
-        `/api/recordings/${recording.wav}/transcribe`,
+        `/api/recordings/${recording.mp3}/transcribe`,
        {
          method: 'POST',
        }
@@ -823,7 +839,7 @@ export function SavedRecordingItem({
        err instanceof Error ? err.message : 'Failed to start transcription'
      );
    }
-  }, [recording.wav]);
+  }, [recording.mp3]);
   return (
     <div className="bg-white rounded-lg shadow-sm hover:shadow-md transition-all duration-300 overflow-hidden mb-3 border border-gray-100 hover:border-gray-200">
@@ -854,7 +870,7 @@ export function SavedRecordingItem({
       />
       <audio
         ref={audioRef}
-        src={`/api/recordings/${fileName}/recording.wav`}
+        src={fileName ? `/api/recordings/${fileName}/recording.mp3` : ''}
         preload="metadata"
         className="hidden"
       />

View File

@@ -34,7 +34,7 @@ export function SavedRecordings(): React.ReactElement {
   return (
     <div className="space-y-1">
       {recordings.map(recording => (
-        <SavedRecordingItem key={recording.wav} recording={recording} />
+        <SavedRecordingItem key={recording.mp3} recording={recording} />
       ))}
     </div>
   );

View File

@@ -27,8 +27,10 @@ export interface RecordingMetadata {
   recordingEndTime: number;
   recordingDuration: number;
   sampleRate: number;
+  channels: number;
   totalSamples: number;
   icon?: Uint8Array;
+  mp3: string;
 }
 export interface TranscriptionMetadata {
@@ -49,7 +51,7 @@ export interface TranscriptionMetadata {
 }
 export interface SavedRecording {
-  wav: string;
+  mp3: string;
   metadata?: RecordingMetadata;
   transcription?: TranscriptionMetadata;
 }

View File

@@ -19,6 +19,8 @@ export declare class ApplicationStateChangedSubscriber {
 export declare class AudioTapStream {
   stop(): void
+  get sampleRate(): number
+  get channels(): number
 }
 export declare class DocStorage {
View File

@@ -54,9 +54,9 @@ impl CATapDescription {
       .as_slice(),
     );
     let obj: *mut AnyObject =
-      unsafe { msg_send![obj, initStereoMixdownOfProcesses: &*processes_array] };
+      unsafe { msg_send![obj, initStereoGlobalTapButExcludeProcesses: &*processes_array] };
     if obj.is_null() {
-      return Err(CoreAudioError::InitStereoMixdownOfProcessesFailed);
+      return Err(CoreAudioError::InitStereoGlobalTapButExcludeProcessesFailed);
     }
     Ok(Self { inner: obj })
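Reviewer note: this is the fix for the silent global tap mentioned in the commit message. initStereoMixdownOfProcesses: mixes down only the listed processes, so a global tap built from an empty process list captured nothing; initStereoGlobalTapButExcludeProcesses: inverts the semantics (tap everything except the listed processes), which is what a system-wide tap needs.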

View File

@@ -24,6 +24,8 @@ pub enum CoreAudioError {
   AllocCATapDescriptionFailed,
   #[error("Call initStereoMixdownOfProcesses on CATapDescription failed")]
   InitStereoMixdownOfProcessesFailed,
+  #[error("Call initStereoGlobalTapButExcludeProcesses on CATapDescription failed")]
+  InitStereoGlobalTapButExcludeProcessesFailed,
   #[error("Get UUID on CATapDescription failed")]
   GetCATapDescriptionUUIDFailed,
   #[error("Get mute behavior on CATapDescription failed")]

View File

@@ -29,6 +29,7 @@ use napi_derive::napi;
 use objc2::{runtime::AnyObject, Encode, Encoding, RefEncode};
 use crate::{
+  audio_stream_basic_desc::read_audio_stream_basic_description,
   ca_tap_description::CATapDescription, device::get_device_uid, error::CoreAudioError,
   queue::create_audio_tap_queue, screen_capture_kit::TappableApplication,
 };
@@ -82,9 +83,17 @@ unsafe impl RefEncode for AudioBufferList {
   const ENCODING_REF: Encoding = Encoding::Pointer(&Self::ENCODING);
 }
+// Audio statistics structure to track audio format information
+#[derive(Clone, Copy, Debug)]
+pub struct AudioStats {
+  pub sample_rate: f64,
+  pub channels: u32,
+}
 pub struct AggregateDevice {
   pub tap_id: AudioObjectID,
   pub id: AudioObjectID,
+  pub audio_stats: Option<AudioStats>,
 }
@@ -118,6 +127,7 @@ impl AggregateDevice {
     Ok(Self {
       tap_id,
       id: aggregate_device_id,
+      audio_stats: None,
     })
   }
@@ -149,6 +159,7 @@ impl AggregateDevice {
     Ok(Self {
       tap_id,
       id: aggregate_device_id,
+      audio_stats: None,
     })
   }
@@ -181,6 +192,7 @@ impl AggregateDevice {
     Ok(Self {
       tap_id,
       id: aggregate_device_id,
+      audio_stats: None,
     })
   }
@@ -188,6 +200,22 @@ impl AggregateDevice {
     &mut self,
     audio_stream_callback: Arc<ThreadsafeFunction<Float32Array, (), Float32Array, true>>,
   ) -> Result<AudioTapStream> {
+    // Read and log the audio format before starting the device
+    let mut audio_stats = AudioStats {
+      sample_rate: 44100.0,
+      channels: 1, // Always set to 1 channel (mono)
+    };
+    if let Ok(audio_format) = read_audio_stream_basic_description(self.tap_id) {
+      // Store the audio format information
+      audio_stats.sample_rate = audio_format.0.mSampleRate;
+      // Always use 1 channel regardless of what the system reports
+      audio_stats.channels = 1;
+    }
+    self.audio_stats = Some(audio_stats);
+    let audio_stats_clone = audio_stats;
     let queue = create_audio_tap_queue();
     let mut in_proc_id: AudioDeviceIOProcID = None;
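Reviewer note: channels is pinned to 1 even when the stream description reports stereo, because the IO proc below downmixes interleaved frames to mono before invoking the JS callback; the stats describe the emitted stream, not the hardware format. The 44100.0 fallback only applies if reading the tap's stream description fails.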
@@ -221,18 +249,33 @@ impl AggregateDevice {
         let samples: &[f32] =
           unsafe { std::slice::from_raw_parts(mData.cast::<f32>(), total_samples) };
-        // Convert to mono if needed
-        let mono_samples: Vec<f32> = if *mNumberChannels > 1 {
-          samples
-            .chunks(*mNumberChannels as usize)
-            .map(|chunk| chunk.iter().sum::<f32>() / *mNumberChannels as f32)
-            .collect()
-        } else {
-          samples.to_vec()
-        };
+        // Check the channel count and data format
+        let channel_count = *mNumberChannels as usize;
+        // Process the audio based on channel count
+        let mut processed_samples: Vec<f32>;
+        if channel_count > 1 {
+          // For stereo, samples are interleaved: [L, R, L, R, ...]
+          // We need to average each pair to get mono
+          let frame_count = total_samples / channel_count;
+          processed_samples = Vec::with_capacity(frame_count);
+          for i in 0..frame_count {
+            let mut frame_sum = 0.0;
+            for c in 0..channel_count {
+              frame_sum += samples[i * channel_count + c];
+            }
+            processed_samples.push(frame_sum / (channel_count as f32));
+          }
+        } else {
+          // Already mono, just copy the samples
+          processed_samples = samples.to_vec();
+        }
+        // Pass the processed samples to the callback
         audio_stream_callback.call(
-          Ok(mono_samples.into()),
+          Ok(processed_samples.into()),
           ThreadsafeFunctionCallMode::NonBlocking,
         );
       }
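Reviewer note: the downmix math as a short TypeScript sketch for reference, matching the Rust loop above; interleaved stereo frames [L0, R0, L1, R1, ...] are averaged per frame.

// Minimal sketch of the same interleaved-to-mono averaging done in the IO proc.
function downmixToMono(samples: Float32Array, channels: number): Float32Array {
  if (channels <= 1) return samples.slice();
  const frames = Math.floor(samples.length / channels);
  const mono = new Float32Array(frames);
  for (let i = 0; i < frames; i++) {
    let sum = 0;
    for (let c = 0; c < channels; c++) {
      sum += samples[i * channels + c];
    }
    mono[i] = sum / channels;
  }
  return mono;
}
// e.g. downmixToMono(Float32Array.of(0.2, 0.4, -1, 1), 2) -> [0.3, 0]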
@@ -266,6 +309,7 @@ impl AggregateDevice {
       device_id: self.id,
       in_proc_id,
       stop_called: false,
+      audio_stats: audio_stats_clone,
     })
   }
@@ -353,6 +397,7 @@ pub struct AudioTapStream {
   device_id: AudioObjectID,
   in_proc_id: AudioDeviceIOProcID,
   stop_called: bool,
+  audio_stats: AudioStats,
 }
 #[napi]
@@ -381,6 +426,16 @@ impl AudioTapStream {
     }
     Ok(())
   }
+  #[napi(getter)]
+  pub fn get_sample_rate(&self) -> f64 {
+    self.audio_stats.sample_rate
+  }
+  #[napi(getter)]
+  pub fn get_channels(&self) -> u32 {
+    self.audio_stats.channels
+  }
 }
 fn cfstring_from_bytes_with_nul(bytes: &'static [u8]) -> CFString {