# Spaces: Running
import os

import torch
def save_checkpoint(
    path,
    encoder,
    decoder,
    optimizer,
    epoch,
    train_loss,
    val_loss
):
    """Write a training checkpoint to ``path``.

    The checkpoint is a dict holding ``epoch``, the ``state_dict()`` of
    ``encoder``, ``decoder`` and ``optimizer``, and the ``train_loss`` /
    ``val_loss`` values, serialized with ``torch.save``.

    When ``path`` is a filesystem path, missing parent directories are
    created and the data is written to a sibling ``<path>.tmp`` file that is
    then moved over ``path``, so an interrupted save does not leave a
    truncated checkpoint in place of a previous good one.

    Args:
        path: Destination file path (``str`` or ``os.PathLike``). Any other
            object (e.g. an open binary file) is handed to ``torch.save``
            unchanged.
        encoder: Object exposing ``state_dict()``.
        decoder: Object exposing ``state_dict()``.
        optimizer: Object exposing ``state_dict()``.
        epoch: Value stored under ``"epoch"``.
        train_loss: Value stored under ``"train_loss"``.
        val_loss: Value stored under ``"val_loss"``.
    """
    state = {
        "epoch": epoch,
        "encoder_state_dict": encoder.state_dict(),
        "decoder_state_dict": decoder.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "train_loss": train_loss,
        "val_loss": val_loss,
    }

    # Non-path targets (file-like objects) keep the plain torch.save path.
    if not isinstance(path, (str, os.PathLike)):
        torch.save(state, path)
        return

    target = os.fspath(path)
    parent = os.path.dirname(target)
    if parent:
        os.makedirs(parent, exist_ok=True)

    # The temp file sits next to the target so the final move is a rename
    # within one directory (os.replace is atomic on POSIX).
    tmp_target = f"{target}.tmp"
    try:
        torch.save(state, tmp_target)
        os.replace(tmp_target, target)
    finally:
        # After a successful replace the temp file is gone; this only fires
        # when saving or moving failed part-way.
        if os.path.exists(tmp_target):
            os.remove(tmp_target)
def load_checkpoint(
    best_path,
    encoder,
    decoder,
    optimizer,
    device
):
    """Restore model (and optionally optimizer) state from a checkpoint.

    Reads the dict stored at ``best_path`` and loads its
    ``"encoder_state_dict"`` and ``"decoder_state_dict"`` entries into
    ``encoder`` and ``decoder``. If ``optimizer`` is not ``None``, its state
    is restored from ``"optimizer_state_dict"``. Progress is reported with
    ``print``.

    Args:
        best_path: Checkpoint location passed to ``torch.load``.
        encoder: Object exposing ``load_state_dict()``.
        decoder: Object exposing ``load_state_dict()``.
        optimizer: Object exposing ``load_state_dict()``, or ``None`` to
            skip restoring optimizer state (e.g. evaluation-only use).
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        A ``(next_epoch, best_val_loss)`` tuple, where ``next_epoch`` is the
        stored ``"epoch"`` plus one and ``best_val_loss`` is the stored
        ``"val_loss"``.

    Raises:
        KeyError: If the checkpoint lacks one of the keys read here.
    """
    print(f"Loading checkpoint: {best_path}")

    # NOTE(review): torch.load relies on unpickling; only load checkpoint
    # files that come from a trusted source.
    checkpoint = torch.load(best_path, map_location=device)

    encoder.load_state_dict(checkpoint["encoder_state_dict"])
    decoder.load_state_dict(checkpoint["decoder_state_dict"])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    # The checkpoint stores the epoch it was written for; training resumes
    # at the one after it.
    next_epoch = checkpoint["epoch"] + 1
    best_val_loss = checkpoint["val_loss"]

    print(
        f"Resume from Epoch {next_epoch} | "
        f"Best Val Loss: {best_val_loss:.4f}"
    )
    return next_epoch, best_val_loss