sol’s blog

모델 실험을 위한 코드 템플릿 생성

sol-commits
sol-commitsMay 17, 2025
모델 실험을 위한 코드 템플릿 생성
목적, 마감, 결과, subtasks

    config dataclass

    from dataclasses import dataclass
    
    @dataclass
    class Config:
        case_name: str = "LSTM" # 바꿔야하는 부분
        
        sequence_length: int = 12
    python
    💡
    꼭 확인해야 하는 부분
    • case_name → 모델 최상위 디렉토리명이 되는 값
    • sequence_length
      • 이 값에 따라 csv 를 자동으로 가져오므로 확인 필요

    early stopping, learning rate scheduler는 제외

    • 현재 실험의 목적은 모델 구조(ex. LSTM, BiLSTM, Attention 등) 간의 성능 차이를 공정하게 비교하는 것
    • early stopping, lr scheduler를 사용하면 공정성이 떨어짐
    • 구조 자체의 차이를 관찰하고 싶을 때는 학습률 변화를 고정한 상태로 실험하는 것이 더 객관적

    → 성능을 최대한 끌어내야 하는 단계에서 이 전략들을 다시 적용

    model_type, data_type

    모델 저장과 결과 저장하는 함수에 model_type, data_type 을 인자로 받아야하는 경우가 있음

    이를 enum으로 코드를 짜서 실수하지 않도록 방지

    from enum import Enum
    
    class ModelType(Enum):
      BEST = "best"
      LAST = "last"
    
    class DataType(Enum):
      TRAIN = "train"
      VAL = "val"
      TEST = "test"
    python
    💡
    model_type
    • ModelType.BEST
    • ModelType.LAST

    모델은 best_model, last_model 둘만 저장하기 때문에, 어떤 모델을 load 할지에 따라 model_type을 넘겨주면 됩니다.

    💡
    data_type
    • DataType.TRAIN
    • DataType.VAL
    • DataType.TEST

    어떤 데이터셋으로 모델 성능을 평가할지에 따라 data_type을 넘겨주면 됩니다.

    모델 성능 평가 및 오분류 시각화

    val_labels, val_preds, val_probs = evaluate_model(model, cfg.device, val_loader, data_type="validation")
    python
    • evaluate_model
      • 모델을 평가하고 평가 결과인 confusion_matrix, classification_report를 저장 경로에 저장, 시각화, (true_labels, pred_labels, 예측 확률) 반환
    • 클래스별 확신 높은 오분류 top-10 인덱스 추출
      이 인덱스 중 하나를 골라 다음 함수를 부름
      이 인덱스 중 하나를 골라 다음 함수를 부름

    최종 결과물

    config.case_name 디렉토리 하위에 모델 결과들이 저장됨

    classification_report, confusion_matrix 폴더를 분리하는 게 나을지…
    classification_report, confusion_matrix 폴더를 분리하는 게 나을지…
    best_train_classification_report.pngbest_model 로 train 성능 평가한 classification_report
    trainval_curve.pngloss, accuracy curve 그래프
    training_results.csv매 epoch 마다의 train, validation 의 loss, accuracy
    weightsbest_model.pth
    last_model.pth
    저장