import os
import sys
import time
from argparse import ArgumentParser

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim.lr_scheduler
from PIL import Image
from torch.nn.parallel import gather
from tqdm import tqdm

# Put the current working directory at the front of the import path so the
# project-local packages below (``model``, ``dataset``) resolve relative to
# it. This must run BEFORE any of those imports; previously
# ``model.trainer`` was imported first and could not benefit from it.
sys.path.insert(0, '.')

import dataset.dataset as myDataLoader  # noqa: E402
import dataset.Transforms as myTransforms  # noqa: E402
from model.metric_tool import ConfuseMatrixMeter  # noqa: E402
from model.trainer import Trainer  # noqa: E402
from model.utils import BCEDiceLoss, init_seed  # noqa: E402
|
|
|
|
def _save_batch_masks(pred_np, batch_file_names, batch_idx, mask_dir):
    """Write each predicted mask of one batch to ``mask_dir`` as a PNG.

    Best-effort: a failure on one mask is reported and the remaining masks of
    the batch are still attempted.

    Args:
        pred_np: uint8 array of 0/1 predictions, indexed as ``pred_np[i, 0]``
            per sample (presumably (B, 1, H, W) — TODO confirm).
        batch_file_names: source paths of the samples in this batch, used to
            derive output names; may be shorter than the batch.
        batch_idx: index of the batch, used for fallback names and messages.
        mask_dir: directory the PNG files are written into.
    """
    tqdm.write(f"\nDebug - Batch {batch_idx}: {len(batch_file_names)} files, Mask shape: {pred_np.shape}")

    for i in range(pred_np.shape[0]):
        if i < len(batch_file_names):
            base_name = os.path.splitext(os.path.basename(batch_file_names[i]))[0]
        else:
            tqdm.write(f"Warning: Missing filename for mask {i}, using default")
            base_name = f"batch_{batch_idx}_mask_{i}"

        # Reset per mask so error reports never show a previous mask's state.
        single_mask = None
        try:
            single_mask = pred_np[i, 0]
            if single_mask.ndim != 2:
                raise ValueError(f"Invalid mask shape: {single_mask.shape}")
            mask_path = os.path.join(mask_dir, f"{base_name}_pred.png")
            # Predictions are 0/1; scale to 0/255 so the PNG is viewable.
            Image.fromarray(single_mask * 255).save(mask_path)
            tqdm.write(f"Saved: {mask_path}")
        except (OSError, ValueError, TypeError, IndexError) as e:
            tqdm.write(f"\nError saving batch {batch_idx}: {str(e)}")
            tqdm.write(f"Current mask shape: {single_mask.shape if single_mask is not None else 'N/A'}")
            tqdm.write(f"Current file: {base_name}")


@torch.no_grad()
def validate(args, val_loader, model, save_masks=False):
    """Evaluate a change-detection model on ``val_loader``.

    Each batch's image tensor holds the pre-change image in channels 0:3 and
    the post-change image in channels 3:6. The model output is thresholded at
    0.5 to get a binary prediction (presumably the model emits probabilities —
    confirm). Predictions are accumulated in a 2-class confusion matrix.

    Args:
        args: namespace; reads ``onGPU``, ``batch_size`` (fallback only) and
            ``savedir`` (when ``save_masks``).
        val_loader: DataLoader yielding ``(img, target)`` pairs. When
            ``save_masks`` is set, its ``sampler.data_source`` must expose a
            ``file_list`` and the loader must not shuffle, otherwise the saved
            names will not match the samples.
        model: callable as ``model(pre_img, post_img)``.
        save_masks: if True, write predicted masks to ``<savedir>/pred_masks``.

    Returns:
        ``(average_loss, scores)``: the mean per-batch BCE-Dice loss (NaN for
        an empty loader) and ``ConfuseMatrixMeter.get_scores()``.
    """
    model.eval()

    # Kept from the original: explicitly put BatchNorm layers in eval mode
    # and flag them as tracking running statistics.
    for m in model.modules():
        if isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)):
            m.track_running_stats = True
            m.eval()

    salEvalVal = ConfuseMatrixMeter(n_class=2)
    epoch_loss = []

    mask_dir = None
    file_list = None
    # Slice file names with the loader's real batch size; args.batch_size may
    # differ from the loader that was actually passed in.
    loader_batch_size = val_loader.batch_size or args.batch_size
    if save_masks:
        mask_dir = os.path.join(args.savedir, "pred_masks")
        os.makedirs(mask_dir, exist_ok=True)
        print(f"Saving prediction masks to: {mask_dir}")
        # Only needed for naming masks, so only looked up when saving.
        file_list = val_loader.sampler.data_source.file_list

    pbar = tqdm(enumerate(val_loader), total=len(val_loader), desc="Validating")

    for batch_idx, (img, target) in pbar:
        pre_img = img[:, 0:3]
        post_img = img[:, 3:6]

        if args.onGPU:
            pre_img = pre_img.cuda()
            post_img = post_img.cuda()
            target = target.cuda()

        target = target.float()
        output = model(pre_img, post_img)
        loss = BCEDiceLoss(output, target)
        pred = (output > 0.5).long()

        if save_masks:
            start = batch_idx * loader_batch_size
            batch_file_names = file_list[start:start + loader_batch_size]
            _save_batch_masks(pred.cpu().numpy().astype(np.uint8),
                              batch_file_names, batch_idx, mask_dir)

        # No torch.nn.parallel.gather here: `pred` is already a single tensor
        # (a DataParallel model gathers its outputs itself). Calling gather on
        # a tensor splits it along the batch axis and re-concatenates it.
        f1 = salEvalVal.update_cm(pr=pred.cpu().numpy(), gt=target.cpu().numpy())
        epoch_loss.append(loss.item())

        pbar.set_postfix({'Loss': f"{loss.item():.4f}", 'F1': f"{f1:.4f}"})

    average_loss = sum(epoch_loss) / len(epoch_loss) if epoch_loss else float('nan')
    scores = salEvalVal.get_scores()
    return average_loss, scores
| |
def ValidateSegmentation(args):
    """Main entry point for the full test-set evaluation pipeline.

    Steps:
    1. Configure the visible GPU, cuDNN benchmarking and the random seed.
    2. Derive the run directory from the short dataset name, max_steps and
       lr, then map the dataset name to its data directory.
    3. Build the model and test DataLoader.
    4. Load ``<savedir>/best_model.pth`` and run ``validate``.
    5. Print the metrics and append them to ``<savedir>/<logFile>``.

    Args:
        args: parsed command-line namespace. ``args.savedir`` and
            ``args.file_root`` are rewritten in place.

    Raises:
        FileNotFoundError: if the checkpoint ``best_model.pth`` is missing.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    torch.backends.cudnn.benchmark = True
    init_seed(args.seed)

    # The run directory uses the SHORT dataset name (e.g. "LEVIR"), so it
    # must be built before file_root is mapped to a data path below.
    args.savedir = os.path.join(args.savedir,
                                f"{args.file_root}_iter_{args.max_steps}_lr_{args.lr}")
    os.makedirs(args.savedir, exist_ok=True)

    dataset_mapping = {
        'LEVIR': './levir_cd_256',
        'WHU': './whu_cd_256',
        'CLCD': './clcd_256',
        'SYSU': './sysu_256',
        'OSCD': './oscd_256'
    }
    # Unknown names are treated as an explicit data directory.
    args.file_root = dataset_mapping.get(args.file_root, args.file_root)

    model = Trainer(args.model_type).float()
    if args.onGPU:
        model = model.cuda()

    # Per-channel normalization for the 6-channel (pre, post) input. The
    # triplets look like ImageNet stats in reversed channel order —
    # NOTE(review): presumably the loader yields BGR; confirm.
    mean = [0.406, 0.456, 0.485, 0.406, 0.456, 0.485]
    std = [0.225, 0.224, 0.229, 0.225, 0.224, 0.229]

    valDataset = myTransforms.Compose([
        myTransforms.Normalize(mean=mean, std=std),
        myTransforms.Scale(args.inWidth, args.inHeight),
        myTransforms.ToTensor()
    ])

    test_data = myDataLoader.Dataset(file_root=args.file_root, mode="test", transform=valDataset)
    testLoader = torch.utils.data.DataLoader(
        test_data,
        batch_size=args.batch_size,
        shuffle=False,  # validate() maps batches back to file names by order
        num_workers=args.num_workers,
        pin_memory=bool(args.onGPU)  # pinned memory only helps host->GPU copies
    )

    model_file_name = os.path.join(args.savedir, 'best_model.pth')
    if not os.path.exists(model_file_name):
        raise FileNotFoundError(f"Model file not found: {model_file_name}")

    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only run;
    # load_state_dict then copies the tensors onto the model's own device.
    state_dict = torch.load(model_file_name, map_location='cpu')
    model.load_state_dict(state_dict)
    print(f"Loaded model from {model_file_name}")

    loss_test, score_test = validate(args, testLoader, model, save_masks=args.save_masks)

    print("\nTest Results:")
    print(f"Loss: {loss_test:.4f}")
    print(f"Kappa: {score_test['Kappa']:.4f}")
    print(f"IoU: {score_test['IoU']:.4f}")
    print(f"F1: {score_test['F1']:.4f}")
    print(f"Recall: {score_test['recall']:.4f}")
    print(f"Precision: {score_test['precision']:.4f}")
    print(f"OA: {score_test['OA']:.4f}")

    # Decide on the header BEFORE opening: opening in any write mode creates
    # the file, so checking afterwards would never write the header.
    logFileLoc = os.path.join(args.savedir, args.logFile)
    write_header = not os.path.exists(logFileLoc)
    with open(logFileLoc, 'a') as logger:
        if write_header:
            logger.write("\n%s\t%s\t%s\t%s\t%s\t%s\t%s" %
                         ('Epoch', 'Kappa', 'IoU', 'F1', 'Recall', 'Precision', 'OA'))
        logger.write("\n%s\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f" %
                     ('Test', score_test['Kappa'], score_test['IoU'], score_test['F1'],
                      score_test['recall'], score_test['precision'], score_test['OA']))
|
|
|
|
def _build_arg_parser():
    """Return the command-line parser for a standalone validation run.

    Options are declared as (flag, keyword-arguments) pairs and registered in
    order, so ``--help`` output and parsed attribute order stay stable.
    """
    option_specs = (
        ('--file_root', dict(default="LEVIR",
                             help='Data directory | LEVIR | WHU | CLCD | SYSU | OSCD')),
        ('--inWidth', dict(type=int, default=256, help='Width of input image')),
        ('--inHeight', dict(type=int, default=256, help='Height of input image')),
        ('--max_steps', dict(type=int, default=80000,
                             help='Max. number of iterations (for path naming)')),
        ('--num_workers', dict(type=int, default=4,
                               help='Number of data loading workers')),
        ('--model_type', dict(type=str, default='small',
                              help='Model type | tiny | small')),
        ('--batch_size', dict(type=int, default=16,
                              help='Batch size for validation')),
        ('--lr', dict(type=float, default=2e-4,
                      help='Learning rate (for path naming)')),
        ('--seed', dict(type=int, default=16,
                        help='Random seed for reproducibility')),
        ('--savedir', dict(default='./results',
                           help='Base directory to save results')),
        ('--logFile', dict(default='testLog.txt',
                           help='File to save validation logs')),
        # Only the literal string "true" (any case) enables the GPU.
        ('--onGPU', dict(default=True,
                         type=lambda x: (str(x).lower() == 'true'),
                         help='Run on GPU if True')),
        ('--gpu_id', dict(type=int, default=0, help='GPU device id')),
        ('--save_masks', dict(action='store_true',
                              help='Save predicted masks to disk')),
    )

    cli = ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli


if __name__ == '__main__':
    cli_args = _build_arg_parser().parse_args()
    print('Validation with args:')
    print(cli_args)
    ValidateSegmentation(cli_args)
|
|
|
|