import os
# Explicit path to the ffprobe binary (pydub locates it via PATH).
ffprobe_path = '/usr/local/bin/ffprobe'
# Prepend ffprobe's directory to PATH so pydub/ffmpeg tooling can find it.
os.environ["PATH"] += os.pathsep + os.path.dirname(ffprobe_path)

import numpy as np
import soundfile as sf
import webrtcvad
from pydub import AudioSegment

def convert_mp3_to_wav(mp3_path, wav_path):
    """Convert an MP3 to a 16 kHz mono 16-bit PCM WAV (the format webrtcvad requires).

    Args:
        mp3_path: Path to the source MP3 file.
        wav_path: Path the converted WAV file is written to.

    Returns:
        str: ``wav_path``, for convenient chaining.
    """
    import tempfile

    # Read the MP3 with pydub and normalize to 16 kHz mono.
    audio = AudioSegment.from_mp3(mp3_path)
    audio = audio.set_frame_rate(16000).set_channels(1)

    # Use a unique temp path so concurrent runs cannot clobber each other,
    # and guarantee cleanup even if the export/re-encode below raises.
    fd, temp_wav = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        audio.export(temp_wav, format="wav")

        # Re-read with soundfile and force 16-bit PCM on the final file.
        data, sr = sf.read(temp_wav)
        # Clip to [-1, 1] before scaling: out-of-range float samples would
        # otherwise wrap around when cast to int16.
        data_int16 = (np.clip(data, -1.0, 1.0) * 32767).astype(np.int16)
        sf.write(wav_path, data_int16, sr, subtype='PCM_16')
    finally:
        # Always remove the intermediate file, even on error.
        if os.path.exists(temp_wav):
            os.remove(temp_wav)
    return wav_path

def detect_voice_segments(wav_path, vad_mode=1, frame_duration=30):
    """Detect voiced intervals in a WAV file using webrtcvad.

    Args:
        wav_path: Path to a 16 kHz mono 16-bit PCM WAV file.
        vad_mode: webrtcvad aggressiveness, 0 (least) to 3 (most aggressive).
        frame_duration: Frame length in ms; webrtcvad accepts only 10, 20 or 30.

    Returns:
        list[tuple[float, float]]: (start, end) times of voiced spans, in seconds.

    Raises:
        ValueError: If frame_duration is unsupported or the file is not 16 kHz.
    """
    # webrtcvad only accepts 10/20/30 ms frames; fail fast with a clear error.
    if frame_duration not in (10, 20, 30):
        raise ValueError("frame_duration must be 10, 20 or 30 (ms)")

    vad = webrtcvad.Vad(vad_mode)

    data, sample_rate = sf.read(wav_path, dtype='int16')
    # Real validation instead of `assert`: asserts vanish under `python -O`.
    if sample_rate != 16000:
        raise ValueError("音频必须是16kHz采样率")

    # Frame geometry: samples per frame, then bytes (16-bit PCM = 2 bytes/sample).
    frame_samples = int(sample_rate * frame_duration / 1000)
    frame_bytes = frame_samples * 2

    # Raw little-endian PCM byte stream, as webrtcvad expects.
    pcm_data = data.tobytes()

    # Classify each full frame as speech / non-speech; drop the trailing partial frame.
    speech_frames = []
    for i in range(0, len(pcm_data), frame_bytes):
        frame = pcm_data[i:i + frame_bytes]
        if len(frame) < frame_bytes:
            break
        speech_frames.append(vad.is_speech(frame, sample_rate))

    # Collapse the per-frame booleans into (start, end) intervals in seconds.
    segments = []
    start = None
    frame_interval = frame_duration / 1000  # seconds per frame

    for i, is_speech in enumerate(speech_frames):
        current_time = i * frame_interval
        if is_speech and start is None:
            start = current_time
        elif not is_speech and start is not None:
            segments.append((start, current_time))
            start = None

    # Close a segment still open at end-of-audio.
    if start is not None:
        segments.append((start, len(speech_frames) * frame_interval))

    return segments

def merge_close_segments(segments, min_gap=0.5):
    """Merge chronologically ordered voice segments separated by at most min_gap.

    Args:
        segments: Sequence of (start, end) pairs in ascending start order.
        min_gap: Maximum silence (seconds) between two segments for them
            to be fused into one.

    Returns:
        list[tuple[float, float]]: A new list of merged (start, end) tuples;
        the input sequence is not modified.
    """
    merged = []
    for start, end in segments:
        if merged and start - merged[-1][1] <= min_gap:
            # Close enough to the previous segment: extend it in place.
            prev_start, prev_end = merged[-1]
            merged[-1] = (prev_start, max(prev_end, end))
        else:
            merged.append((start, end))
    return merged

def clip_mp3_voice(audio_path, output_path, min_segment_length=0.5, min_gap=0.5, vad_mode=1):
    """Cut the voiced (speech) parts out of an MP3 and write them as a new MP3.

    Args:
        audio_path: Input MP3 path.
        output_path: Output MP3 path.
        min_segment_length: Discard detected segments shorter than this (seconds).
        min_gap: Merge segments whose silence gap is at most this (seconds).
        vad_mode: Aggressiveness forwarded to detect_voice_segments.

    Returns:
        bool: True if a clipped file was written; False when no usable
        speech was detected (no output file is produced in that case).
    """
    import tempfile

    # Unique temp path: a fixed name like "temp_vad.wav" would collide when
    # several runs share a working directory.
    fd, temp_wav = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        # Normalize the audio to the format the VAD requires.
        convert_mp3_to_wav(audio_path, temp_wav)

        print("正在检测人声区间...")
        segments = detect_voice_segments(temp_wav, vad_mode=vad_mode)

        if not segments:
            print("未检测到人声片段")
            return False

        # Drop segments that are too short to be meaningful speech.
        segments = [(s, e) for s, e in segments if (e - s) >= min_segment_length]
        if not segments:
            print("所有人声片段都过短，已过滤")
            return False

        # Fuse segments separated by brief pauses.
        merged_segments = merge_close_segments(segments, min_gap)
        print(f"检测到{len(merged_segments)}个人声片段，准备剪辑...")

        # Slice the ORIGINAL mp3 (full quality) and concatenate the kept spans.
        audio = AudioSegment.from_mp3(audio_path)
        result = AudioSegment.empty()

        for i, (start, end) in enumerate(merged_segments):
            print(f"提取片段 {i+1}/{len(merged_segments)}: {start:.2f}s - {end:.2f}s")
            # pydub indexes in milliseconds.
            start_ms = start * 1000
            end_ms = end * 1000
            result += audio[start_ms:end_ms]

        result.export(output_path, format="mp3")
        print(f"剪辑完成，输出文件: {output_path}")
        print(f"原始时长: {len(audio)/1000:.2f}s，保留时长: {len(result)/1000:.2f}s")
        return True

    finally:
        # Always remove the intermediate WAV, even on error.
        if os.path.exists(temp_wav):
            os.remove(temp_wav)

if __name__ == "__main__":
    input_mp3 = "../o_eng.mp3"       # 输入MP3文件路径
    output_mp3 = "output_voice.mp3"  # 输出文件路径
    
    if not os.path.exists(input_mp3):
        print(f"错误: 找不到文件 '{input_mp3}'")
    else:
        clip_mp3_voice(
            audio_path=input_mp3,
            output_path=output_mp3,
            min_segment_length=0.3,  # 最小片段长度（秒）
            min_gap=0.8,             # 合并间隙（秒）
            vad_mode=1               # 检测灵敏度（0-2）
        )

