CHAPTER 7. 에너지 기반 모델

2025. 3. 2. 16:45읽어보기, 교재/만들면서 배우는 생성 AI

에너지 기반 모델은 물리 시스템 모델링에서 핵심 아이디어를 차용한 광범위한 생성 모델 분야. 실숫값의 에너지 함수를 0과 1 사이로 정규화하는 함수인 볼츠만 분포로 어떤 이벤트의 확률을 표현할 수 있다는 것.

 

7.2 에너지 기반 모델(Energy Based Model)

  • 에너지 기반 모델은 볼츠만 분포를 사용해 실제 데이터 생성 분포를 모델링. => 0과 1 사이로 정규화.

볼츠만 분포

  • E(x)는 샘플 x의 에너지 함수(혹은 점수).
  • 신경망 E(x)를 훈련시켜 가능성 높은 샘플은 낮은 점수(0), 가능성 낮은 샘플은 높은 점수 출력(1).(에너지를 다 써야 좋은 샘플이다!)
  • 이러한 방식의 두가지 문제점.
    1. 점수가 낮은 샘플(그럴듯한 샘플)은 어떻게 생성할까?
    2. p(x)가 유효한 확률분포여야 하는데 분모의 적분이 어려움. 

=> 에너지 기반 모델의 핵심은 근사 기법 사용해 분모를 계산할 필요가 없음. => 노멀라이징 플로와 대조적.

  • 에너지 기반 모델의 아이디어 : (훈련을 위한) 대조 발산 기법, (샘플링을 위한) 랑주뱅 동역학 기법.

7.2.2 에너지 함수

  • 에너지 함수 E_ θ(x)
    • 파라미터가 θ이고 입력 이미지 x를 하나의 스칼라 값으로 반환
    • 스위시 활성화 함수 사용.
    • 신경망은 Conv2D 층 쌓아 특징 추출.
    • 마지막 층은 선형 활성화를 가진 하나의 완전 연결 유닛. => (-∞ ~ ∞) 범위를 출력.

* 스위시 활성화 함수 : ReLU의 대한으로 구글에서 제안.

7.2.3 랑주뱅 동역학을 사용해 샘플링하기

  • 에너지 함수는 입력에 대해 하나의 점수만 출력하는데, 이 함수를 사용해 에너지 점수가 낮은 새로운 샘플을 생성하려면? => 랑주뱅 동역학 기법 사용.
  • 랑주뱅 동역학
    • 입력에 대한 에너지 함수의 그레디언트 계산.
      1. 샘플 공간의 임의의 지점에서 시작.
      2. 계산된 그레디언트의 반대 방향으로 조금씩 이동하며 에너지 함수를 감소.
      3. 샘플 공간 이동할 때 소량의 랜덤한 잡음 추가.(local minimum 예방, 확률적 경사 랑주뱅 동역학)

이 경로는 입력 x에 대한 에너지 함수 E(x)의 음의 그레디언트를 따른 잡음이 있는 내리막길

  • 신경망 경사하강법과 차이점
    • 신경망 : 역전파 사용해 가중치의 손실 함수의 그레디언트를 계산. 그 다음 음의 그레디언트 방향으로 파라미터를 조금 업데이트해 점진적 손실 최소화.
    • 랑주뱅 동역학 : 신경망 가중치 고정하고 입력에 대한 출력의 그레디언트를 계산. 점진적으로 출력(에너지 점수)을 최소화.

7.2.4 대조 발산으로 훈련하기

  • 에너지 함수를 확률 출력하지 않아 최대 가능도 추정 적용 불가.

=> 제프리 힌튼이 제안한 대조 발산 기법 적용.

  • p(x)가 에너지 함수 E(x)를 포함하는 볼츠만 분포 형태 일 때 음의 로그 가능도를 최소화.

  • 실제 샘플에 대해서는 큰 음의 에너지 점수, 생성된 가짜 샘플에 대해서는 큰 양의 에너지 점수 출력. => 두 극단의 차이가 커지게(점수 차이가 손실 함수).
  • 가짜 샘플의 에너지 점수 계산 하려면 분포 p(x)에서 정확히 샘플링 해야됨.

=> 분모 계산 어려워서 불가능 하므로 랑주뱅 샘플링 방법 사용해 낮은 에너지 점수 가진 샘플 집합 생성.

  • 랑주뱅 샘플링 과정을 무한히 많이 실행 해야하지만 의미 있는 적은 스텝만 수행.
  • 이전 번복의 샘플을 버퍼에 저장해 다음 배치의 시작점으로 랜덤 잡음과 섞어서 사용. 

대조발산의 한 스텝

  • 진짜 샘플은 점수는 올라가고 가짜 샘플 점수는 내려감.

7.2.6 에너지 기반 모델 분석

  • 훈련 스텝에서 계산된 손실은 변화가 거의 없음. => 일반적이지 않음.
  • 대조 발산이 시간이 지남에 따라 실제 이미지와 랜덤한 잡음을 구별하는 능력이 향상.

7.2.6 기타 에너지 기반 모델

  • 초기 EBM 모델은 랑주뱅 샘플링을 사용하지 않음. => 볼츠만 머신.
  • 볼츠만 머신 : 훈련은 대조 발산을 통해 이루어지지만, 깁스 샘플링 이라는 방법으로 균형 찾음.

=> 훈련 속도 느리고 은닉 유닛의 개수를 크게 늘릴 수 없음.

  • 볼츠만 머신을 확장한 제한된 볼츠만 머신(RBM), RBM을 쌓은 심층 신뢰 신경망.

=> 여전히 성능 안좋음.

  • EBM이 확립되고 랑주뱅 동역학이 선호하는 샘플링 방법이 되고 점수매칭 훈련 기법으로 발전.

=> 잡음 제거 확산 확률 모델 모델로 발전.(DALL E 2, Imagen)


* 후에 깨달은 것

  • 에너지 함수를 사용하는 이유 : 샘플의 가능성을 나타낼 수 있는 좋은 지표이기 때문.
  • 볼츠만 분포를 사용하는 이유 : 에너지 함수의 범위는 -  ~ ∞ 라 활용이 불가(범위가 끝이 없어 뭐가 좋은건지 안좋은 건지 모름) 하므로 정규화를 위해 사용.
  • 랑주백 동역학을 사용하는 이유 : 에너지 함수를 볼츠만 분포로 만들 때 분모 계산이 너무 어려워 직접 계산하지 않고 근사하는 방법.
  • 대조 발산이 감소하는 이유
    • 대조 발산의 한 스텝에서 실제 샘플 점수는 내려가고 가짜 샘플 점수는 올라감.
      • 진짜와 가짜의 차이를 확실하게 구분.
    • 대조 발산 식은 다음과 같음.
cdiv_loss = tf.reduce_mean(fake_out, axis = 0) - tf.reduce_mean(real_out, axis = 0)

 

    • 언뜻 보면 가짜 샘플 점수가 올라가고 진짜 샘플 점수가 내려가면 양으로 발산할 것 같음.
    • 하지만 학습 할 수록 가짜 샘플은 더욱 진짜 같아 지고 진짜 샘플은 상대적으로 덜 진짜 같아 보이니 감소.
  • 대조 발산은 정규화 되지 않은 점수를 출력하는 모델이므로 점수 정규화 X. => 하지만 학습 시 loss 값은 alpha 값을 곱해 정규화.
reg_loss = self.alpha * tf.reduce_mean(
	real_out ** 2 + fake_out ** 2, axis = 0
)
728x90