NeMo speaker embedding model (TitaNet-L) training code


  1. The official docs are written for Jupyter notebooks; this restructures them for a plain .py environment.
  2. Shell commands are executed via subprocess.
  3. For the exact data structure, download an4 and inspect it yourself.
import os
import glob
import subprocess
import tarfile
import wget

########################################
##      1. Download the an4 data      ##
########################################
##  *** Data structure ***
# ./data/an4/wav/an4clstk/<per-speaker folder>/13 recordings (every speaker reads the same script)

data_dir = 'path_your_data_home'
train_speakers_folder_dir = 'your_folder_name'
test_speakers_folder_dir = 'your_folder_name'
# an4_train_speakers_folder_dir = 'path_your_an4'

# os.makedirs(data_dir, exist_ok=True)

# # Download the dataset. This will take a few moments...
# print("******")
# if not os.path.exists(data_dir + '/an4_sphere.tar.gz'):
#     an4_url = 'https://dldata-public.s3.us-east-2.amazonaws.com/an4_sphere.tar.gz'
#     an4_path = wget.download(an4_url, data_dir)
#     print(f"Dataset downloaded at: {an4_path}")
# else:
#     print("Tarfile already exists.")
#     an4_path = data_dir + '/an4_sphere.tar.gz'

# # Untar and convert .sph to .wav (using sox)
# tar = tarfile.open(an4_path)
# tar.extractall(path=data_dir)
# tar.close()

# print("Converting .sph to .wav...")
# sph_list = glob.glob(data_dir + '/an4/**/*.sph', recursive=True)
# for sph_path in sph_list:
#     wav_path = sph_path[:-4] + '.wav'
#     cmd = ["sox", sph_path, wav_path]
#     subprocess.run(cmd)
# print("Finished conversion.\n******")

#########################################
##      2. Build the train txt file    ##
#########################################
command = f"find {data_dir}/{train_speakers_folder_dir} -iname '*.wav' > {data_dir}/{train_speakers_folder_dir}/train_all.txt"
subprocess.run(command, shell=True)
command = f'head -n 3 {data_dir}/{train_speakers_folder_dir}/train_all.txt'
subprocess.run(command, shell=True)
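# For reference, train_all.txt is just one .wav path per line; the head output
# should look roughly like this (paths illustrative):
#   path_your_data_home/your_folder_name/speaker_0001/utt01.wav
#   path_your_data_home/your_folder_name/speaker_0001/utt02.wav
#   path_your_data_home/your_folder_name/speaker_0002/utt01.wav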

##########################################################################
##   3. Fetch the manifest-creation script and build the train manifest ##
##########################################################################
working_dir = 'path_your_working'
# if not os.path.exists(f'{working_dir}/scripts'):
#   print("Downloading necessary scripts")
#   subprocess.run(f'mkdir -p {working_dir}/scripts/speaker_tasks', shell=True)
#   subprocess.run(f'wget -P {working_dir}/scripts/speaker_tasks/ https://raw.githubusercontent.com/NVIDIA/NeMo/main/scripts/speaker_tasks/filelist_to_manifest.py', shell=True)
# Manual download: https://github.com/NVIDIA/NeMo/blob/main/scripts/speaker_tasks/filelist_to_manifest.py
subprocess.run(f'python {working_dir}/scripts/speaker_tasks/filelist_to_manifest.py --filelist {data_dir}/{train_speakers_folder_dir}/train_all.txt --id -2 --out {data_dir}/{train_speakers_folder_dir}/all_manifest.json --split', shell=True)
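# The script emits one JSON object per line; a typical speaker-recognition
# manifest line looks like this (values illustrative):
#   {"audio_filepath": "/.../speaker_0001/utt01.wav", "offset": 0, "duration": 3.9, "label": "speaker_0001"}
# With --split it also writes train.json and dev.json next to all_manifest.json,
# which is what section 5 points at.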

###############################################
##   4. Build the test txt file and manifest ##
###############################################
subprocess.run(f'find {data_dir}/{test_speakers_folder_dir}  -iname "*.wav" > {data_dir}/{test_speakers_folder_dir}/test_all.txt', shell=True)
subprocess.run(f'python {working_dir}/scripts/speaker_tasks/filelist_to_manifest.py --filelist {data_dir}/{test_speakers_folder_dir}/test_all.txt --id -2 --out {data_dir}/{test_speakers_folder_dir}/test.json', shell=True)

#########################################
##   5. Set the manifest file paths    ##
#########################################
train_manifest = os.path.join(data_dir, f'{train_speakers_folder_dir}/train.json')
validation_manifest = os.path.join(data_dir, f'{train_speakers_folder_dir}/dev.json')
test_manifest = os.path.join(data_dir, f'{test_speakers_folder_dir}/test.json')

######################
##   6. Training    ##
######################
import nemo
import nemo.collections.asr as nemo_asr
from omegaconf import OmegaConf

## step 1. Download the model config file
subprocess.run(f'mkdir -p {working_dir}/conf', shell=True)
subprocess.run(f'wget -P {working_dir}/conf https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/recognition/conf/titanet-large.yaml', shell=True)
MODEL_CONFIG = os.path.join(working_dir, 'conf/titanet-large.yaml')
config = OmegaConf.load(MODEL_CONFIG)
print(OmegaConf.to_yaml(config))

## step 2. Inspect the dataset dicts in the config file
print(OmegaConf.to_yaml(config.model.train_ds))
print(OmegaConf.to_yaml(config.model.validation_ds))

## step 3. Point manifest_filepath at the paths above and set the number of speakers
# We're training here, so the test dataset isn't needed yet.
config.model.train_ds.manifest_filepath = train_manifest                # train
config.model.validation_ds.manifest_filepath = validation_manifest      # validation
config.model.decoder.num_classes = 3066                                 # number of speakers
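# num_classes must equal the number of unique speaker labels in the train manifest
# (3066 is this run's speaker count). A minimal sketch for deriving it instead of
# hardcoding, assuming each manifest line carries a "label" field as
# filelist_to_manifest.py emits:
# import json
# with open(train_manifest) as f:
#     num_speakers = len({json.loads(line)['label'] for line in f if line.strip()})
# config.model.decoder.num_classes = num_speakers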

## step 4. Build the trainer with PyTorch Lightning
import torch
import pytorch_lightning as pl  # the official docs use `import lightning.pytorch as pl`, but that raised an error here.
torch.set_float32_matmul_precision('high')  # 'medium' favors speed; 'high' keeps more precision at some cost in speed.
# Inspect the trainer config
print("Trainer config - \n")
print(OmegaConf.to_yaml(config.trainer))

# Modify the trainer config
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
config.trainer.devices = 1
config.trainer.accelerator = accelerator
# config.trainer.enable_checkpointing = True
config.trainer.max_epochs = 500
config.trainer.strategy = 'auto'        # distributed-training strategy (multi-GPU)
config.model.train_ds.augmentor = None  # disable data augmentation
trainer = pl.Trainer(**config.trainer)
print(trainer)

## step 5. Set up NeMo's logging and checkpoint manager
from nemo.utils.exp_manager import exp_manager

exp_manager_config = config.get("exp_manager", {})
exp_manager_config['exp_dir'] = 'path_your_log'     # new log directory
exp_manager_config['name'] = '20250203'             # experiment name (change if needed)
exp_manager_config['version'] = '1'                 # version (change if needed)

log_dir = exp_manager(trainer, exp_manager_config)
print(log_dir)
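# log_dir points at the run folder; if I read exp_manager right, runs are laid
# out as {exp_dir}/{name}/{version} with checkpoints in a checkpoints/ subfolder.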

## step 6. Build the TitaNet model
speaker_model = nemo_asr.models.EncDecSpeakerLabelModel(cfg=config.model, trainer=trainer)
print(speaker_model)

## step 7. Train
trainer.fit(speaker_model)
# trainer.test(speaker_model, ckpt_path=None)  # evaluate on the test dataset
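# Optional: besides the .ckpt files exp_manager writes, the trained model can be
# exported as a single .nemo archive (filename illustrative):
# speaker_model.save_to(f'{working_dir}/titanet_finetuned.nemo')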

#################################
##      7. Load the model      ##
#################################
final_checkpoint = 'path_your_checkpoint'   # a Lightning .ckpt file written by exp_manager
speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.load_from_checkpoint(final_checkpoint)
print(speaker_model.summarize())
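# Note: load_from_checkpoint is for Lightning .ckpt files; a .nemo archive is
# loaded with restore_from instead:
# speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.restore_from('titanet_finetuned.nemo')

# Quick sanity check with the restored model (audio paths illustrative):
# emb = speaker_model.get_embedding('some_utterance.wav')             # speaker embedding tensor
# decision = speaker_model.verify_speakers('utt_a.wav', 'utt_b.wav')  # True if same speaker
# print(emb.shape, decision)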

Hope this helps someone.
