Ejemplo de cómo usar Ignite
/ᐠ。‸。ᐟ\ Una librería de alto nivel que facilita el entrenamiento y la evaluación de redes neuronales en PyTorch
Recientemente han salido varias librerías de alto nivel para PyTorch. Una vez planteados los datasets y los modelos todavía queda bastante por programar, sobre todo los bucles de entrenamiento y validación, código que salvando detalles de implementación es siempre el mismo (boilerplate). Librerías como PyTorch Lighting y PyTorch Ignite prometen ahorrarnos el código repetitivo y concentrarnos en lo particular.
Se instala con
pip install pytorch-ignite
import torch
from sklearn.metrics import balanced_accuracy_score
from ignite.engine import Engine, Events
from ignite.metrics import Accuracy, Loss
if torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
Definimos el triángulo modelo-optimizador-criterio. Lo de siempre.
modelo = Modelo().to(device)
optimizador = torch.optim.Adam(modelo.parameters(), lr=1e-3, weight_decay=1e-5)
criterio = torch.nn.CrossEntropyLoss()
Bucle de entrenamiento.
def entrenar(motor, lote):
modelo.train()
optimizador.zero_grad()
predicciones = modelo(lote.documentos.to(device))
pérdida = criterio(predicciones, lote.etiquetas.to(device))
pérdida.backward()
optimizador.step()
return pérdida.item()
entrenador = Engine(entrenar)
Bucle de evaluación. Sirve para todo lo que no es entrenamiento, a ser validación e inferencia.
def evaluar(motor, lote):
modelo.eval()
with torch.no_grad():
predicciones = modelo(lote.documentos.to(device))
# validación
if lote.etiquetas is not None:
return predicciones, lote.etiquetas.to(device)
# inferencia
return predicciones
evaluador = Engine(evaluar)
inferidor = Engine(evaluar)
Adjuntamos algunas métricas al evaluador. Notar que entrenar
devuelve el valor de la función de pérdida mientras que evaluar
devuelve predicciones.
Accuracy().attach(evaluador, 'accuracy')
Loss(criterio).attach(evaluador, 'loss')
- Evento: al completar una época de entrenamiento.
-
Acción: registrar métricas del dataset de entrenamiento. Para ello usamos
evaluador.run(train_dl)
.
@entrenador.on(Events.EPOCH_COMPLETED)
def loguear_resultados_entrenamiento(entrenador):
evaluador.run(train_dl)
# accedemos a este atributo gracias a haber adjuntado métricas previamente
métricas = evaluador.state.metrics
print(f"[{entrenador.state.epoch:02}] TRAIN Accuracy: {métricas['accuracy']:.2f} Loss: {métricas['loss']:.2f}")
- Evento: al completar una época de entrenamiento. Acá podría ser al completar X épocas para no validar tan seguido.
-
Acción: registrar métricas del dataset de validación. Para ello usamos
evaluador.run(valid_dl)
.
Esta función podría hacer sido más parecida a la de arriba pero no lo es porque queremos calcular una métrica que no viene con Ignite. Tenemos dos opciones, definir una métrica adjuntable como ignite.metrics.Accuracy
—que no lo hicimos y quizás hubiese sido lo mejor— o poner la lógica en la función, como vemos aquí.
@entrenador.on(Events.EPOCH_COMPLETED)
def loguear_resultados_validación(entrenador):
# esta artimaña tendrá sentido más adelante
evaluador.predicciones = []
evaluador.etiquetas = []
evaluador.run(valid_dl)
# de todas las categorías nos quedamos con la más probable para cada muestra
predicciones = torch.cat(evaluador.predicciones).argmax(dim=1).cpu()
etiquetas = torch.cat(evaluador.etiquetas).cpu()
score = balanced_accuracy_score(etiquetas, predicciones)
métricas = evaluador.state.metrics
print(f"[{entrenador.state.epoch:02}] VALID Accuracy: {métricas['accuracy']:.2f} Loss: {métricas['loss']:.2f} Balanced accuracy: {score:.2f}")
- Evento: al procesar un lote de validación.
- Acción: almacenar predicciones y etiquetas.
Esto le da sentido a la artimaña que mencionamos. La misma sirve para instanciar listas vacías al inicio de la validación, a las que se le agregaran los resultados de cada lote.
@evaluador.on(Events.ITERATION_COMPLETED)
def colectar_validaciones_lote(evaluador):
evaluador.predicciones.append(evaluador.state.output[0])
evaluador.etiquetas.append(evaluador.state.output[1])
- Evento: al completar el entrenamiento (todas las épocas).
-
Acción: realizar inferencias. Para ello usamos
inferidor.run(infer_dl)
.
@entrenador.on(Events.COMPLETED)
def colectar_inferencias(entrenador):
print('Realizando inferencias...')
# mismo truco de antes
inferidor.y_pred = []
inferidor.run(infer_dl)
# de todas las categorías nos quedamos con la más probable para cada muestra
y_pred = torch.cat(inferidor.y_pred).argmax(dim=1).reshape(-1,1)
# quizás sea un buen momento para recuperar las categorías originales
#y_pred = vocabulario_etiquetas.índices_a_tókenes(y_pred)
# ya que estamos, guardamos los resultados en un CSV
pd.DataFrame(y_pred).to_csv('submit.csv', header=False)
- Evento: al procesar un lote de inferencia.
- Acción: almacenar predicciones.
@inferidor.on(Events.ITERATION_COMPLETED)
def colectar_inferencias_lote(inferidor):
inferidor.y_pred.append(inferidor.state.output)
Finalmente largamos el entrenamiento con entrenador.run(train_dl)
. Este es el engranaje principal del mecanismo, que al completar bucles moverá a los otros engranajes (Engine
s, que no son literalmente enganajes pero puede que sea una buena metáfora).
entrenador.run(train_dl, max_epochs=5)