Recientemente han salido varias librerías de alto nivel para PyTorch. Una vez planteados los datasets y los modelos todavía queda bastante por programar, sobre todo los bucles de entrenamiento y validación, código que salvando detalles de implementación es siempre el mismo (boilerplate). Librerías como PyTorch Lighting y PyTorch Ignite prometen ahorrarnos el código repetitivo y concentrarnos en lo particular.

Se instala con

pip install pytorch-ignite
import torch

from sklearn.metrics import balanced_accuracy_score

from ignite.engine import Engine, Events
from ignite.metrics import Accuracy, Loss
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

Definimos el triángulo modelo-optimizador-criterio. Lo de siempre.

modelo = Modelo().to(device)

optimizador = torch.optim.Adam(modelo.parameters(), lr=1e-3, weight_decay=1e-5)

criterio = torch.nn.CrossEntropyLoss()

Bucle de entrenamiento.

def entrenar(motor, lote):
    modelo.train()

    optimizador.zero_grad()
    
    predicciones = modelo(lote.documentos.to(device))
    pérdida = criterio(predicciones, lote.etiquetas.to(device))
    
    pérdida.backward()
    optimizador.step()
    
    return pérdida.item()


entrenador = Engine(entrenar)

Bucle de evaluación. Sirve para todo lo que no es entrenamiento, a ser validación e inferencia.

def evaluar(motor, lote):
    modelo.eval()

    with torch.no_grad():
        predicciones = modelo(lote.documentos.to(device))
    
    # validación
    if lote.etiquetas is not None:
        return predicciones, lote.etiquetas.to(device)
    
    # inferencia
    return predicciones


evaluador = Engine(evaluar)
inferidor = Engine(evaluar)

Adjuntamos algunas métricas al evaluador. Notar que entrenar devuelve el valor de la función de pérdida mientras que evaluar devuelve predicciones.

Accuracy().attach(evaluador, 'accuracy')
Loss(criterio).attach(evaluador, 'loss')
  • Evento: al completar una época de entrenamiento.
  • Acción: registrar métricas del dataset de entrenamiento. Para ello usamos evaluador.run(train_dl).
@entrenador.on(Events.EPOCH_COMPLETED)
def loguear_resultados_entrenamiento(entrenador):
    evaluador.run(train_dl)
    
    # accedemos a este atributo gracias a haber adjuntado métricas previamente
    métricas = evaluador.state.metrics
    
    print(f"[{entrenador.state.epoch:02}] TRAIN  Accuracy: {métricas['accuracy']:.2f}  Loss: {métricas['loss']:.2f}")
  • Evento: al completar una época de entrenamiento. Acá podría ser al completar X épocas para no validar tan seguido.
  • Acción: registrar métricas del dataset de validación. Para ello usamos evaluador.run(valid_dl).

Esta función podría hacer sido más parecida a la de arriba pero no lo es porque queremos calcular una métrica que no viene con Ignite. Tenemos dos opciones, definir una métrica adjuntable como ignite.metrics.Accuracy —que no lo hicimos y quizás hubiese sido lo mejor— o poner la lógica en la función, como vemos aquí.

@entrenador.on(Events.EPOCH_COMPLETED)
def loguear_resultados_validación(entrenador):
    # esta artimaña tendrá sentido más adelante
    evaluador.predicciones = []
    evaluador.etiquetas = []
    
    evaluador.run(valid_dl)
    
    # de todas las categorías nos quedamos con la más probable para cada muestra
    predicciones = torch.cat(evaluador.predicciones).argmax(dim=1).cpu()
    etiquetas    = torch.cat(evaluador.etiquetas).cpu()
    
    score = balanced_accuracy_score(etiquetas, predicciones)
    
    métricas = evaluador.state.metrics
    
    print(f"[{entrenador.state.epoch:02}] VALID  Accuracy: {métricas['accuracy']:.2f}  Loss: {métricas['loss']:.2f}  Balanced accuracy: {score:.2f}")
  • Evento: al procesar un lote de validación.
  • Acción: almacenar predicciones y etiquetas.

Esto le da sentido a la artimaña que mencionamos. La misma sirve para instanciar listas vacías al inicio de la validación, a las que se le agregaran los resultados de cada lote.

@evaluador.on(Events.ITERATION_COMPLETED)
def colectar_validaciones_lote(evaluador):
    evaluador.predicciones.append(evaluador.state.output[0])
    evaluador.etiquetas.append(evaluador.state.output[1])
  • Evento: al completar el entrenamiento (todas las épocas).
  • Acción: realizar inferencias. Para ello usamos inferidor.run(infer_dl).
@entrenador.on(Events.COMPLETED)
def colectar_inferencias(entrenador):
    print('Realizando inferencias...')
    
    # mismo truco de antes
    inferidor.y_pred = []
    
    inferidor.run(infer_dl)
    
    # de todas las categorías nos quedamos con la más probable para cada muestra
    y_pred = torch.cat(inferidor.y_pred).argmax(dim=1).reshape(-1,1)
    
    # quizás sea un buen momento para recuperar las categorías originales
    #y_pred = vocabulario_etiquetas.índices_a_tókenes(y_pred)
    
    # ya que estamos, guardamos los resultados en un CSV
    pd.DataFrame(y_pred).to_csv('submit.csv', header=False)
  • Evento: al procesar un lote de inferencia.
  • Acción: almacenar predicciones.
@inferidor.on(Events.ITERATION_COMPLETED)
def colectar_inferencias_lote(inferidor):
    inferidor.y_pred.append(inferidor.state.output)

Finalmente largamos el entrenamiento con entrenador.run(train_dl). Este es el engranaje principal del mecanismo, que al completar bucles moverá a los otros engranajes (Engines, que no son literalmente enganajes pero puede que sea una buena metáfora).

entrenador.run(train_dl, max_epochs=5)
[01] VALID  Accuracy: 0.80  Loss: 1.11  Balanced accuracy: 0.75
[01] TRAIN  Accuracy: 0.98  Loss: 0.02
[02] VALID  Accuracy: 0.80  Loss: 1.11  Balanced accuracy: 0.75
[02] TRAIN  Accuracy: 0.98  Loss: 0.02
[03] VALID  Accuracy: 0.80  Loss: 1.11  Balanced accuracy: 0.75
[03] TRAIN  Accuracy: 0.98  Loss: 0.02
[04] VALID  Accuracy: 0.80  Loss: 1.11  Balanced accuracy: 0.75
[04] TRAIN  Accuracy: 0.98  Loss: 0.02
[05] VALID  Accuracy: 0.80  Loss: 1.11  Balanced accuracy: 0.75
[05] TRAIN  Accuracy: 0.98  Loss: 0.02
Realizando inferencias...