| from tqdm import tqdm |
| import network |
| import utils |
| import os |
| import random |
| import argparse |
| import numpy as np |
| import json |
|
|
| from torch.utils import data |
| from datasets import VOCSegmentation, Cityscapes |
| from utils import ext_transforms as et |
| from metrics import StreamSegMetrics |
| from torch.utils.tensorboard import SummaryWriter |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from PIL import Image |
| import matplotlib |
| import matplotlib.pyplot as plt |
|
|
|
|
def get_argparser():
    """Build the command-line argument parser for training/evaluation.

    Returns:
        argparse.ArgumentParser: parser covering dataset, model, training
        and checkpoint options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--out_dir", type=str, default="run_0")

    # Dataset options
    parser.add_argument("--data_root", type=str, default='',
                        help="path to Dataset")
    parser.add_argument("--dataset", type=str, default='voc',
                        choices=['voc'], help='Name of dataset')
    parser.add_argument("--num_classes", type=int, default=None,
                        help="num classes (default: None)")

    # Model options
    parser.add_argument("--model", type=str, default='deeplabv3plus_resnet101',
                        choices=['deeplabv3plus_resnet101'], help='model name')
    parser.add_argument("--separable_conv", action='store_true', default=False,
                        help="apply separable conv to decoder and aspp")
    parser.add_argument("--output_stride", type=int, default=16, choices=[8, 16])

    # Training options
    parser.add_argument("--test_only", action='store_true', default=False)
    parser.add_argument("--save_val_results", action='store_true', default=False,
                        help="save segmentation results to \"./results\"")
    # NOTE: the default used to be the float literal 30e3; argparse applies
    # `type=int` only to command-line strings, so the default was silently a
    # float. Use an int literal (same value) to keep the type consistent.
    parser.add_argument("--total_itrs", type=int, default=30000,
                        help="total training iterations (default: 30k)")
    parser.add_argument("--lr", type=float, default=0.02,
                        help="learning rate (default: 0.02)")
    parser.add_argument("--lr_policy", type=str, default='poly', choices=['poly', 'step'],
                        help="learning rate scheduler policy")
    parser.add_argument("--step_size", type=int, default=10000)
    parser.add_argument("--crop_val", action='store_true', default=False,
                        help='crop validation (default: False)')
    parser.add_argument("--batch_size", type=int, default=32,
                        help='batch size (default: 32)')
    parser.add_argument("--val_batch_size", type=int, default=4,
                        help='batch size for validation (default: 4)')
    parser.add_argument("--crop_size", type=int, default=513)

    # Checkpoint options
    parser.add_argument("--ckpt", default=None, type=str,
                        help="restore from checkpoint")
    parser.add_argument("--continue_training", action='store_true', default=False)

    parser.add_argument("--loss_type", type=str, default='cross_entropy',
                        choices=['cross_entropy', 'focal_loss'],
                        help="loss type (default: cross_entropy)")
    parser.add_argument("--gpu_id", type=str, default='0,1',
                        help="GPU ID")
    parser.add_argument("--weight_decay", type=float, default=1e-4,
                        help='weight decay (default: 1e-4)')
    parser.add_argument("--random_seed", type=int, default=1,
                        help="random seed (default: 1)")
    parser.add_argument("--print_interval", type=int, default=10,
                        help="print interval of loss (default: 10)")
    parser.add_argument("--val_interval", type=int, default=100,
                        help="iteration interval for eval (default: 100)")
    parser.add_argument("--download", action='store_true', default=False,
                        help="download datasets")

    # PASCAL VOC options
    parser.add_argument("--year", type=str, default='2012_aug',
                        choices=['2012_aug', '2012', '2011', '2009', '2008', '2007'], help='year of VOC')
    return parser
|
|
|
|
def get_dataset(opts):
    """Build the train/val dataset pair with augmentation for `opts.dataset`.

    Args:
        opts: parsed options; reads dataset, crop_size, crop_val, data_root,
            year and download.

    Returns:
        tuple: (train_dst, val_dst) dataset objects.

    Raises:
        ValueError: if `opts.dataset` is neither 'voc' nor 'cityscapes'.
    """
    if opts.dataset == 'voc':
        train_transform = et.ExtCompose([
            et.ExtRandomScale((0.5, 2.0)),
            et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size), pad_if_needed=True),
            et.ExtRandomHorizontalFlip(),
            et.ExtToTensor(),
            # ImageNet normalization statistics (pretrained backbone).
            et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
        ])
        if opts.crop_val:
            val_transform = et.ExtCompose([
                et.ExtResize(opts.crop_size),
                et.ExtCenterCrop(opts.crop_size),
                et.ExtToTensor(),
                et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225]),
            ])
        else:
            # No cropping: validation images keep their native sizes.
            val_transform = et.ExtCompose([
                et.ExtToTensor(),
                et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225]),
            ])
        train_dst = VOCSegmentation(root=opts.data_root, year=opts.year,
                                    image_set='train', download=opts.download, transform=train_transform)
        val_dst = VOCSegmentation(root=opts.data_root, year=opts.year,
                                  image_set='val', download=False, transform=val_transform)

    elif opts.dataset == 'cityscapes':
        train_transform = et.ExtCompose([
            et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size)),
            et.ExtColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
            et.ExtRandomHorizontalFlip(),
            et.ExtToTensor(),
            et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
        ])

        val_transform = et.ExtCompose([
            et.ExtToTensor(),
            et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
        ])

        train_dst = Cityscapes(root=opts.data_root,
                               split='train', transform=train_transform)
        val_dst = Cityscapes(root=opts.data_root,
                             split='val', transform=val_transform)
    else:
        # Previously an unknown dataset fell through to the return with
        # train_dst/val_dst undefined (UnboundLocalError); fail explicitly.
        raise ValueError("Unknown dataset: %s" % opts.dataset)
    return train_dst, val_dst
|
|
|
|
def validate(opts, model, loader, device, metrics, ret_samples_ids=None):
    """Run evaluation over `loader` and return (score, ret_samples).

    Args:
        opts: parsed options; reads save_val_results.
        model: segmentation network producing per-pixel class logits.
        loader: validation DataLoader yielding (images, labels) batches.
        device: torch device to evaluate on.
        metrics: StreamSegMetrics accumulator (reset at the start).
        ret_samples_ids: optional collection of batch indices; for each, the
            batch's first (image, target, pred) triple is collected for
            visualization.

    Returns:
        tuple: (score, ret_samples) — the metrics result dict and the
        collected visualization samples.
    """
    metrics.reset()
    ret_samples = []
    if opts.save_val_results:
        if not os.path.exists('results'):
            os.mkdir('results')
        denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                                   std=[0.229, 0.224, 0.225])
        img_id = 0

    with torch.no_grad():
        for i, (images, labels) in tqdm(enumerate(loader)):

            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            outputs = model(images)
            # argmax over the class dimension -> per-pixel predicted labels.
            preds = outputs.detach().max(dim=1)[1].cpu().numpy()
            targets = labels.cpu().numpy()

            metrics.update(targets, preds)
            if ret_samples_ids is not None and i in ret_samples_ids:
                ret_samples.append(
                    (images[0].detach().cpu().numpy(), targets[0], preds[0]))

            if opts.save_val_results:
                # Use a distinct index so the batch counter `i` above is
                # not shadowed (original code reused `i` here).
                for k in range(len(images)):
                    image = images[k].detach().cpu().numpy()
                    target = targets[k]
                    pred = preds[k]

                    image = (denorm(image) * 255).transpose(1, 2, 0).astype(np.uint8)
                    target = loader.dataset.decode_target(target).astype(np.uint8)
                    pred = loader.dataset.decode_target(pred).astype(np.uint8)

                    Image.fromarray(image).save('results/%d_image.png' % img_id)
                    Image.fromarray(target).save('results/%d_target.png' % img_id)
                    Image.fromarray(pred).save('results/%d_pred.png' % img_id)

                    # Overlay the prediction on the input image with all
                    # axes/ticks stripped, then save.
                    plt.figure()
                    plt.imshow(image)
                    plt.axis('off')
                    plt.imshow(pred, alpha=0.7)
                    ax = plt.gca()
                    ax.xaxis.set_major_locator(matplotlib.ticker.NullLocator())
                    ax.yaxis.set_major_locator(matplotlib.ticker.NullLocator())
                    plt.savefig('results/%d_overlay.png' % img_id, bbox_inches='tight', pad_inches=0)
                    plt.close()
                    img_id += 1

    score = metrics.get_results()
    return score, ret_samples
|
|
def main(opts):
    """Train (or, with --test_only, evaluate) a segmentation model.

    Side effects: writes TensorBoard logs under ./logs, checkpoints under
    ./checkpoints, best scores to checkpoints/best_score.txt, and summary
    metrics to final_info.json.
    """
    # num_classes is determined by the dataset choice.
    if opts.dataset.lower() == 'voc':
        opts.num_classes = 21
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Seed all RNGs for reproducibility.
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    writer = SummaryWriter(log_dir='logs')

    # Without crop_val, VOC validation images have varying sizes, so the
    # validation loader must run with batch size 1.
    if opts.dataset == 'voc' and not opts.crop_val:
        opts.val_batch_size = 1

    train_dst, val_dst = get_dataset(opts)
    train_loader = data.DataLoader(
        train_dst, batch_size=opts.batch_size, shuffle=True, num_workers=2,
        drop_last=True)
    val_loader = data.DataLoader(
        val_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2)
    print("Dataset: %s, Train set: %d, Val set: %d" %
          (opts.dataset, len(train_dst), len(val_dst)))

    # Build the model; lower backbone BN momentum for fine-tuning stability.
    model = network.modeling.__dict__[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    metrics = StreamSegMetrics(opts.num_classes)

    # The pretrained backbone trains at a 10x smaller learning rate than
    # the freshly-initialized classifier head.
    optimizer = torch.optim.SGD(params=[
        {'params': model.backbone.parameters(), 'lr': 0.1 * opts.lr},
        {'params': model.classifier.parameters(), 'lr': opts.lr},
    ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)

    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size, gamma=0.1)

    # 255 is the "ignore" label used by the dataset encodings.
    if opts.loss_type == 'focal_loss':
        criterion = utils.FocalLoss(ignore_index=255, size_average=True)
    elif opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')

    def save_ckpt(path):
        """Save the current training state (model/optimizer/scheduler)."""
        torch.save({
            "cur_itrs": cur_itrs,
            # model is wrapped in DataParallel before training starts,
            # so unwrap via .module for a portable state dict.
            "model_state": model.module.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_score": best_score,
        }, path)
        print("Model saved as %s" % path)

    # Single dir-creation point (the original created it twice).
    os.makedirs('checkpoints', exist_ok=True)

    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0

    # Convert BN layers before (optionally) restoring weights so the
    # checkpoint's state-dict keys line up with the converted model.
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    if opts.test_only:
        model.eval()
        val_score, ret_samples = validate(
            opts=opts, model=model, loader=val_loader, device=device, metrics=metrics)
        print(metrics.to_str(val_score))
        writer.close()
        return

    interval_loss = 0
    # Rolling window of the two most recent "latest" checkpoints on disk.
    latest_checkpoints = []
    while True:
        model.train()
        cur_epochs += 1
        for (images, labels) in train_loader:
            cur_itrs += 1

            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            np_loss = loss.detach().cpu().numpy()
            interval_loss += np_loss
            writer.add_scalar('Loss/train', np_loss, cur_itrs)

            if (cur_itrs) % 10 == 0:
                interval_loss = interval_loss / 10
                print("Epoch %d, Itrs %d/%d, Loss=%f" %
                      (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
                interval_loss = 0.0

            if (cur_itrs) % opts.val_interval == 0:
                ckpt_path = f'checkpoints/latest_{cur_itrs}_{opts.model}_{opts.dataset}_os{opts.output_stride}.pth'
                save_ckpt(ckpt_path)
                latest_checkpoints.append(ckpt_path)
                # Drop the oldest checkpoint once more than two are kept.
                if len(latest_checkpoints) > 2:
                    oldest_ckpt_path = latest_checkpoints.pop(0)
                    try:
                        os.remove(oldest_ckpt_path)
                        print(f"Successfully removed old checkpoint: {oldest_ckpt_path}")
                    except FileNotFoundError:
                        print(f"Warning: Could not remove checkpoint because it was not found: {oldest_ckpt_path}")
                    except OSError as e:
                        print(f"Error removing checkpoint {oldest_ckpt_path}: {e}")

                print("validation...")
                model.eval()
                val_score, ret_samples = validate(
                    opts=opts, model=model, loader=val_loader, device=device, metrics=metrics)
                print(metrics.to_str(val_score))

                writer.add_scalar('Metrics/Mean_IoU', val_score['Mean IoU'], cur_itrs)
                writer.add_scalar('Metrics/Overall_Acc', val_score['Overall Acc'], cur_itrs)
                writer.add_scalar('Metrics/Mean_Acc', val_score['Mean Acc'], cur_itrs)

                if val_score['Mean IoU'] > best_score:
                    best_score = val_score['Mean IoU']
                    save_ckpt(f'checkpoints/best_{opts.model}_{opts.dataset}_os{opts.output_stride}.pth')
                    with open(f'checkpoints/best_score.txt', 'a') as f:
                        f.write(f"iter:{cur_itrs}\n{str(best_score)}\n")
                    with open(f"final_info.json", "w") as f:
                        final_info = {
                            "voc12_aug": {
                                "means": {
                                    "mIoU": val_score['Mean IoU'],
                                    "OA": val_score['Overall Acc'],
                                    # BUG FIX: mAcc previously duplicated
                                    # Mean IoU instead of reporting Mean Acc.
                                    "mAcc": val_score['Mean Acc']
                                }
                            }
                        }
                        json.dump(final_info, f, indent=4)

                model.train()
            scheduler.step()

            if cur_itrs >= opts.total_itrs:
                writer.close()
                return
|
|
|
|
if __name__ == '__main__':
    args = get_argparser().parse_args()
    try:
        main(args)
    except Exception:
        import traceback
        print("Original error in subprocess:", flush=True)
        # Context manager ensures the log file handle is closed (the
        # original passed an open() result directly and leaked it).
        with open("traceback.log", "w") as log_file:
            traceback.print_exc(file=log_file)
        raise
|
|