| from collections import namedtuple |
|
|
| import numpy as np |
| import torch |
|
|
| from .detectors import build_detector |
|
|
| try: |
| import kornia |
| except: |
| pass |
| |
|
|
|
|
|
|
| def build_network(model_cfg, num_class, dataset): |
| model = build_detector( |
| model_cfg=model_cfg, num_class=num_class, dataset=dataset |
| ) |
| return model |
|
|
|
|
| def load_data_to_gpu(batch_dict): |
| for key, val in batch_dict.items(): |
| if key == 'camera_imgs': |
| batch_dict[key] = val.cuda() |
| elif not isinstance(val, np.ndarray): |
| continue |
| elif key in ['frame_id', 'metadata', 'calib', 'image_paths','ori_shape','img_process_infos']: |
| continue |
| elif key in ['images']: |
| batch_dict[key] = kornia.image_to_tensor(val).float().cuda().contiguous() |
| elif key in ['image_shape']: |
| batch_dict[key] = torch.from_numpy(val).int().cuda() |
| else: |
| batch_dict[key] = torch.from_numpy(val).float().cuda() |
|
|
|
|
| def model_fn_decorator(): |
| ModelReturn = namedtuple('ModelReturn', ['loss', 'tb_dict', 'disp_dict']) |
|
|
| def model_func(model, batch_dict): |
| load_data_to_gpu(batch_dict) |
| ret_dict, tb_dict, disp_dict = model(batch_dict) |
|
|
| loss = ret_dict['loss'].mean() |
| if hasattr(model, 'update_global_step'): |
| model.update_global_step() |
| else: |
| model.module.update_global_step() |
|
|
| return ModelReturn(loss, tb_dict, disp_dict) |
|
|
| return model_func |
|
|