NeMo speaker embedding model (TitaNet-L) training code
- The official docs are written for Jupyter notebooks; this reorganizes them for a plain .py environment.
- Shell commands are executed via subprocess.
- For the exact data layout, download an4 and inspect it directly.
import os
import glob
import subprocess
import tarfile
import wget
########################################
## 1. Download the an4 sample dataset ##
########################################
## *** Data layout ***
# ./data/an4/wav/an4clstk/<one folder per speaker>/13 wav files each (every speaker reads the same script)
data_dir = 'path_your_data_home'  # root directory holding the datasets
train_speakers_folder_dir = 'your_folder_name'  # folder with the training speakers' wav files
test_speakers_folder_dir = 'your_folder_name'   # folder with the test speakers' wav files
# an4_train_speakers_folder_dir = 'path_your_an4'
# os.makedirs(data_dir, exist_ok=True)
# # Download the dataset. This will take a few moments...
# print("******")
# if not os.path.exists(data_dir + '/an4_sphere.tar.gz'):
#     an4_url = 'https://dldata-public.s3.us-east-2.amazonaws.com/an4_sphere.tar.gz'
#     an4_path = wget.download(an4_url, data_dir)
#     print(f"Dataset downloaded at: {an4_path}")
# else:
#     print("Tarfile already exists.")
#     an4_path = data_dir + '/an4_sphere.tar.gz'
# # Untar and convert .sph to .wav (using sox)
# tar = tarfile.open(an4_path)
# tar.extractall(path=data_dir)
# print("Converting .sph to .wav...")
# sph_list = glob.glob(data_dir + '/an4/**/*.sph', recursive=True)
# for sph_path in sph_list:
#     wav_path = sph_path[:-4] + '.wav'
#     cmd = ["sox", sph_path, wav_path]
#     subprocess.run(cmd)
# print("Finished conversion.\n******")
#########################################
## 2. Build the train file list (txt)  ##
#########################################
command = f"find {data_dir}/{train_speakers_folder_dir} -iname '*.wav' > {data_dir}/{train_speakers_folder_dir}/train_all.txt"
subprocess.run(command, shell=True)
command = f'head -n 3 {data_dir}/{train_speakers_folder_dir}/train_all.txt'
subprocess.run(command, shell=True)
########################################################################
## 3. Fetch the manifest-creation script and build the train manifest ##
########################################################################
working_dir = 'path_your_working'  # working directory for scripts and configs
# if not os.path.exists(f'{working_dir}/scripts'):
#     print("Downloading necessary scripts")
#     subprocess.run(f'mkdir -p {working_dir}/scripts/speaker_tasks', shell=True)
#     subprocess.run(f'wget -P {working_dir}/scripts/speaker_tasks/ https://raw.githubusercontent.com/NVIDIA/NeMo/main/scripts/speaker_tasks/filelist_to_manifest.py', shell=True)
# Manual download: https://github.com/NVIDIA/NeMo/blob/main/scripts/speaker_tasks/filelist_to_manifest.py
# --id -2 uses the second-to-last path component (the speaker folder) as the label; --split also writes train/dev manifests
subprocess.run(f'python {working_dir}/scripts/speaker_tasks/filelist_to_manifest.py --filelist {data_dir}/{train_speakers_folder_dir}/train_all.txt --id -2 --out {data_dir}/{train_speakers_folder_dir}/all_manifest.json --split', shell=True)
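## (Optional) Preview the generated train manifest. A minimal sanity-check sketch,
## assuming filelist_to_manifest.py wrote one JSON object per line with keys such as
## audio_filepath, offset, duration, and label:
import json
from itertools import islice
with open(f'{data_dir}/{train_speakers_folder_dir}/train.json') as f:
    for line in islice(f, 3):  # print the first three manifest entries
        print(json.loads(line))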
##############################################
## 4. Build the test file list and manifest ##
##############################################
subprocess.run(f'find {data_dir}/{test_speakers_folder_dir} -iname "*.wav" > {data_dir}/{test_speakers_folder_dir}/test_all.txt', shell=True)
subprocess.run(f'python {working_dir}/scripts/speaker_tasks/filelist_to_manifest.py --filelist {data_dir}/{test_speakers_folder_dir}/test_all.txt --id -2 --out {data_dir}/{test_speakers_folder_dir}/test.json', shell=True)
####################################
## 5. Set the manifest file paths ##
####################################
train_manifest = os.path.join(data_dir, f'{train_speakers_folder_dir}/train.json')
validation_manifest = os.path.join(data_dir, f'{train_speakers_folder_dir}/dev.json')
test_manifest = os.path.join(data_dir, f'{train_speakers_folder_dir}/dev.json')  # the dev split is reused as the test manifest here
#####################
## 6. Training     ##
#####################
import nemo
import nemo.collections.asr as nemo_asr
from omegaconf import OmegaConf
## step 1. Download the model config file
subprocess.run(f'mkdir -p {working_dir}/conf', shell=True)
subprocess.run(f'wget -P {working_dir}/conf https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/recognition/conf/titanet-large.yaml', shell=True)
MODEL_CONFIG = os.path.join(working_dir, 'conf/titanet-large.yaml')
config = OmegaConf.load(MODEL_CONFIG)
print(OmegaConf.to_yaml(config))
## step 2. Inspect the dataset sections of the config
print(OmegaConf.to_yaml(config.model.train_ds))
print(OmegaConf.to_yaml(config.model.validation_ds))
## step 3. Set manifest_filepath to the paths above and set the speaker count
# This is model training, so the test dataset is not needed yet
config.model.train_ds.manifest_filepath = train_manifest  # train
config.model.validation_ds.manifest_filepath = validation_manifest  # validation
config.model.decoder.num_classes = 3066  # number of speakers in the training set; see the sketch below
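## num_classes must match the number of unique speaker labels in the train manifest.
## A minimal sketch to verify it, assuming the manifest lines carry a 'label' key as above:
import json
with open(train_manifest) as f:
    num_speakers = len({json.loads(line)['label'] for line in f})
print(f'unique speakers in train manifest: {num_speakers}')  # should equal num_classes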
## step 4. Build the trainer with PyTorch Lightning
import torch
import pytorch_lightning as pl  # the official docs use `import lightning.pytorch as pl`, but that errors here
torch.set_float32_matmul_precision('high')  # 'medium': balances speed and precision; 'high': favors precision (slower)
# Inspect the trainer config
print("Trainer config - \n")
print(OmegaConf.to_yaml(config.trainer))
# Modify the trainer config
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
config.trainer.devices = 1
config.trainer.accelerator = accelerator
# config.trainer.enable_checkpointing = True
config.trainer.max_epochs = 500
config.trainer.strategy = 'auto'  # distributed-training strategy (for multi-GPU setups)
config.model.train_ds.augmentor = None  # disable data augmentation
trainer = pl.Trainer(**config.trainer)
print(trainer)
## step 5. Set up NeMo's logging and checkpoint manager
from nemo.utils.exp_manager import exp_manager
exp_manager_config = config.get("exp_manager", {})
exp_manager_config['exp_dir'] = 'path_your_log'  # set a new log directory
exp_manager_config['name'] = '20250203'  # experiment name (change if needed)
exp_manager_config['version'] = '1'  # version (change if needed)
log_dir = exp_manager(trainer, exp_manager_config)
print(log_dir)
## step 6. Build the TitaNet model
speaker_model = nemo_asr.models.EncDecSpeakerLabelModel(cfg=config.model, trainer=trainer)
print(speaker_model)
## step 7. Train
trainer.fit(speaker_model)
# trainer.test(speaker_model, ckpt_path=None)  # evaluate on the test dataset
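## (Optional) Export the trained model for reuse. A minimal sketch using NeMo's
## save_to, which writes a self-contained .nemo archive (the output path is an assumption):
speaker_model.save_to(f'{working_dir}/titanet-large-trained.nemo')  # hypothetical path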
#########################
## 7. Load the model   ##
#########################
final_checkpoint = 'path_your_checkpoint'  # .ckpt file saved under exp_dir by exp_manager
speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.load_from_checkpoint(final_checkpoint)
print(speaker_model.summarize())
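## (Optional) Extract a speaker embedding with the restored model. A minimal sketch
## using EncDecSpeakerLabelModel.get_embedding (the wav path is a placeholder):
speaker_model.eval()
emb = speaker_model.get_embedding('path_your_wav/sample.wav')
print(emb.shape)  # TitaNet-L produces a 192-dimensional embedding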
Hope this helps someone.