NeMo speaker embedding model(TitaNet-L) FInetuning 코드

2025. 2. 3. 16:16연구하기, 지식

2025.02.03 - [R&D] - NeMo speaker embedding model(TitaNet-L) 학습 코드

NeMo speaker embedding model(TitaNet-L) 학습 코드

공식문서에는 jupyter notebook 기준으로 나와 있지만 py 환경에 맞게 재구성.각종 명령어들은 subprocess를 사용하여 실행.자세한 데이터 구조는 an4를 직접 다운로드 후 참고.import osimport globimport subproce

meerkat-developer.tistory.com

위 학습을 무조건 선행 해보고 진행!

  1. 공식문서에는 jupyter notebook 기준으로 나와 있지만 py 환경에 맞게 재구성.
  2. 각종 명령어들은 subprocess를 사용하여 실행.
  3. 자세한 데이터 구조는 an4를 직접 다운로드 후 참고.

import os
import nemo
import nemo.collections.asr as nemo_asr
from omegaconf import OmegaConf
import torch
import pytorch_lightning as pl
from nemo.utils.exp_manager import exp_manager

#################################################
##      1. Finetune용 yaml 파일 다운로드        ##
#################################################
# 직접 다운로드
# https://github.com/NVIDIA/NeMo/blob/main/examples/speaker_tasks/recognition/conf/titanet-finetune.yaml
working_dir = 'your_dir'
MODEL_CONFIG = os.path.join(working_dir,'conf/titanet-finetune.yaml')
finetune_config = OmegaConf.load(MODEL_CONFIG)
print(OmegaConf.to_yaml(finetune_config))
# 체크포인트를 Finetuning 하려면 yaml 파일에 init_from_pretrained_model 이 부분을 init_from_nemo_model 이걸로 교체

#############################################
##      2. Finetune용 데이터 불러오기       ##
#############################################
# 선행 되어야 할 작업 : 'NeMo speaker embedding model(TitaNet-L) 학습 코드'의 4번까지 완료
data_dir = os.path.join('your_dir')
json_dir = 'your_dir'
test_manifest = os.path.join(data_dir,json_dir)
finetune_config.model.train_ds.manifest_filepath = test_manifest
finetune_config.model.validation_ds.manifest_filepath = test_manifest
finetune_config.model.decoder.num_classes = 3066        # 화자 수

######################################
##      3. trainer 객체 생성        ##
######################################
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'

trainer_config = OmegaConf.create(dict(
    devices=1,
    accelerator=accelerator,
    max_epochs=5,
    max_steps=-1,
    num_nodes=1,
    accumulate_grad_batches=1,
    enable_checkpointing=False,
    logger=False,
    log_every_n_steps=1,
    val_check_interval=1.0,
))
print(OmegaConf.to_yaml(trainer_config))
trainer_finetune = pl.Trainer(**trainer_config)

###########################################################
##      4. NeMo의 로깅 및 체크포인트 관리자 불러오기        ##
###########################################################
exp_manager_config = finetune_config.get("exp_manager", None)
exp_manager_config['exp_dir'] = 'your_dir'
exp_manager_config['name'] = '20250203_finetune'

log_dir_finetune = exp_manager(trainer_finetune, exp_manager_config)
print(log_dir_finetune)

##################################
##      5. Finetuning 하기      ##
##################################
speaker_model = nemo_asr.models.EncDecSpeakerLabelModel(cfg=finetune_config.model, trainer=trainer_finetune)
speaker_model.maybe_init_from_pretrained_checkpoint(finetune_config)

trainer_finetune.fit(speaker_model)
728x90