# Mini-ImageNet/src/utils/checkpoint_manager.py
# ImAMJayKIM's picture
# Upload 96 files
# c1596ac verified
import os
import torch
def save_checkpoint(
    path,
    encoder,
    decoder,
    optimizer,
    epoch,
    train_loss,
    val_loss
):
    """Serialize the training state to ``path`` with ``torch.save``.

    The checkpoint is a plain dict with the keys ``epoch``,
    ``encoder_state_dict``, ``decoder_state_dict``,
    ``optimizer_state_dict``, ``train_loss`` and ``val_loss``.

    Args:
        path: Destination of the checkpoint. When it is a filesystem path
            (``str`` or ``os.PathLike``), missing parent directories are
            created first. Any other object is handed to ``torch.save``
            unchanged.
        encoder: Object whose ``state_dict()`` is stored under
            ``encoder_state_dict``.
        decoder: Object whose ``state_dict()`` is stored under
            ``decoder_state_dict``.
        optimizer: Object whose ``state_dict()`` is stored under
            ``optimizer_state_dict``.
        epoch: Value stored under ``epoch``.
        train_loss: Value stored under ``train_loss``.
        val_loss: Value stored under ``val_loss``.
    """
    # torch.save does not create directories; a fresh run directory would
    # otherwise make the very first save fail.
    if isinstance(path, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(path))
        # An empty dirname means "current directory": nothing to create.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    torch.save({
        "epoch": epoch,
        "encoder_state_dict": encoder.state_dict(),
        "decoder_state_dict": decoder.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "train_loss": train_loss,
        "val_loss": val_loss
    }, path)
def load_checkpoint(
    best_path,
    encoder,
    decoder,
    optimizer,
    device
):
    """Restore encoder/decoder (and optionally optimizer) state from a checkpoint.

    The checkpoint is expected to be a dict containing the keys
    ``encoder_state_dict``, ``decoder_state_dict``, ``optimizer_state_dict``,
    ``epoch`` and ``val_loss``.

    Args:
        best_path: Checkpoint location passed to ``torch.load``.
        encoder: Object updated in place via ``load_state_dict``.
        decoder: Object updated in place via ``load_state_dict``.
        optimizer: Object updated in place via ``load_state_dict``. Pass
            ``None`` to restore only the model weights (e.g. for
            evaluation); the optimizer entry is then ignored.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        A ``(next_epoch, best_val_loss)`` tuple, where ``next_epoch`` is the
        stored ``epoch`` plus one and ``best_val_loss`` is the stored
        ``val_loss``.

    Raises:
        KeyError: If a required key is absent from the checkpoint.
    """
    print(f"Loading checkpoint: {best_path}")

    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(
        best_path,
        map_location=device
    )

    encoder.load_state_dict(checkpoint["encoder_state_dict"])
    decoder.load_state_dict(checkpoint["decoder_state_dict"])

    # Inference-only callers have no optimizer to restore; skipping avoids
    # an AttributeError on ``None.load_state_dict``.
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    start_epoch = checkpoint["epoch"]
    best_val_loss = checkpoint["val_loss"]

    # The stored epoch is presumably the last completed one, so training
    # resumes at the following epoch -- confirm against the training loop.
    print(
        f"Resume from Epoch {start_epoch+1} | "
        f"Best Val Loss: {best_val_loss:.4f}"
    )
    return start_epoch+1, best_val_loss