| |
| |
| |
| import argparse |
| import torch |
| from torch import nn |
|
|
| import utils, data_setup, model, engine |
| import yaml |
|
|
| |
# Seed the random number generators once, before any data loading or model
# construction happens (exactly what gets seeded is decided by utils.set_seed).
utils.set_seed(0)


def _build_parser():
    """Return the command-line parser for the training hyperparameters.

    ``fromfile_prefix_chars='@'`` lets arguments be supplied from a file:
    an ``@path`` token on the command line is replaced by the arguments
    listed in ``path`` (one per line).
    """
    cli = argparse.ArgumentParser(fromfile_prefix_chars='@')

    # (short flag, long flag, help text, value type, default value),
    # registered in this order.
    options = (
        ('-nw', '--num-workers', 'Number of workers for dataloaders.', int, 0),
        ('-ne', '--num-epochs', 'Number of epochs to train model for.', int, 25),
        ('-bs', '--batch-size', 'Size of batches to split training set.', int, 100),
        ('-lr', '--learning-rate', 'Learning rate for the optimizer.', float, 0.001),
        ('-p', '--patience', 'Number of epochs to wait before early stopping.', int, 10),
        ('-md', '--min-delta', 'Minimum decrease in loss to reset patience.', float, 0.001),
    )
    for short_flag, long_flag, help_text, value_type, default in options:
        cli.add_argument(short_flag, long_flag, help=help_text,
                         type=value_type, default=default)
    return cli


parser = _build_parser()

# Parsed at import time into a module-level namespace; long flags become
# attributes with dashes replaced by underscores (e.g. args.num_workers).
args = parser.parse_args()
|
|
|
|
| |
| |
| |
if __name__ == '__main__':
    # Local import: only the script entry point needs it.
    import os

    # Echo the effective hyperparameters between two rules of '#', with a bold
    # (ANSI escape) heading. The rule is built outside the f-string because
    # reusing the enclosing quote character inside an f-string replacement
    # field is a SyntaxError before Python 3.12.
    rule = '#' * 50
    print(f'{rule}\n'
          f'\033[1mTraining hyperparameters:\033[0m \n'
          f' - num-workers: {args.num_workers} \n'
          f' - num-epochs: {args.num_epochs} \n'
          f' - batch-size: {args.batch_size} \n'
          f' - learning-rate: {args.learning_rate} \n'
          f' - patience: {args.patience} \n'
          f' - min-delta: {args.min_delta} \n'
          f'{rule}')

    # Train/test dataloaders. The root path is relative to the current working
    # directory (presumably MNIST, judging by the directory name).
    train_dl, test_dl = data_setup.get_dataloaders(root='./mnist_data',
                                                   batch_size=args.batch_size,
                                                   num_workers=args.num_workers)

    # Output directory for the settings YAML and (passed to engine.train below)
    # the saved model. Also relative to the current working directory.
    save_dir = '../saved_models'
    # The settings file is written before training starts, so the directory
    # must exist now; otherwise open() raises FileNotFoundError on a fresh run.
    os.makedirs(save_dir, exist_ok=True)

    base_name = 'tiny_vgg_less_compute'
    mod_name = f'{base_name}_model.pth'

    # Constructor kwargs for model.TinyVGG; saved to the settings YAML so the
    # same architecture can be rebuilt. num_classes comes from the training
    # dataset's `classes` attribute.
    mod_kwargs = {
        'num_blks': 2,
        'num_convs': 2,
        'in_channels': 1,
        'hidden_channels': 5,
        'fc_hidden_dim': 128,
        'num_classes': len(train_dl.dataset.classes),
    }

    vgg_mod = model.TinyVGG(**mod_kwargs).to(utils.DEVICE)
    # torch.compile(module) returns a new wrapper and leaves `module` untouched,
    # so discarding its result meant the model was never compiled.
    # nn.Module.compile() compiles in place, which also keeps state_dict keys
    # free of the '_orig_mod.' prefix the wrapper adds to saved checkpoints.
    vgg_mod.compile()

    # Persist the CLI hyperparameters and model kwargs next to the checkpoint.
    settings_path = os.path.join(save_dir, f'{base_name}_settings.yaml')
    with open(settings_path, 'w', encoding='utf-8') as f:
        yaml.dump({'train_kwargs': vars(args), 'mod_kwargs': mod_kwargs}, f)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(params=vgg_mod.parameters(), lr=args.learning_rate)

    # Train with early stopping (patience / min_delta) and let engine.train save
    # the model under save_dir/mod_name; its return value is kept in mod_res.
    mod_res = engine.train(model=vgg_mod,
                           train_dl=train_dl,
                           test_dl=test_dl,
                           loss_fn=loss_fn,
                           optimizer=optimizer,
                           num_epochs=args.num_epochs,
                           patience=args.patience,
                           min_delta=args.min_delta,
                           device=utils.DEVICE,
                           save_mod=True,
                           save_dir=save_dir,
                           mod_name=mod_name)
|
|