| | from typing import Optional, Union |
| | from tqdm.auto import trange |
| | from PIL import ImageOps |
| | from PIL import Image |
| | from torch import nn |
| | import numpy as np |
| | import torch |
| | import cv2 |
| |
|
| |
|
class MidasDepth(nn.Module):
    """Thin wrapper around a MiDaS model loaded through ``torch.hub``.

    The network and the DPT input transform are both fetched from the
    ``intel-isl/MiDaS`` hub repository when the object is constructed.
    The network is placed in eval mode with gradients disabled.
    """

    def __init__(self, model_type="DPT_Large",
                 device=torch.device(
                     "cuda" if torch.cuda.is_available() else "cpu"),
                 is_inpainting=False):
        """Load the MiDaS model and its input transform.

        Args:
            model_type: Entry-point name passed to ``torch.hub.load`` for
                the ``intel-isl/MiDaS`` repository.
            device: ``torch.device`` (or a device string such as
                ``"cuda"``) to run the model on. A request for ``"mps"``
                is redirected to the CPU.
            is_inpainting: Accepted for interface compatibility; not used
                anywhere in this class.
        """
        super().__init__()
        # Accept plain strings too: the ``.type`` check below needs a real
        # torch.device object.
        if not isinstance(device, torch.device):
            device = torch.device(device)
        if device.type == "mps":
            device = torch.device("cpu")
        self.device = device
        self.model = torch.hub.load(
            "intel-isl/MiDaS", model_type).to(self.device).eval().requires_grad_(False)
        # NOTE(review): the DPT transform is used for every model_type;
        # confirm that this is intended for the non-DPT MiDaS variants.
        self.transform = torch.hub.load(
            "intel-isl/MiDaS", "transforms").dpt_transform

    @torch.no_grad()
    def forward(self, image):
        """Run the model on one image and resize the output to the input size.

        Args:
            image: Tensor, ndarray or array-like. Tensors are moved to the
                CPU first; everything is converted to a NumPy array and
                singleton dimensions are squeezed away. Presumably an
                H x W x C image in the 0-255 range expected by the hub
                transform -- TODO confirm against the MiDaS transforms.

        Returns:
            Tensor on ``self.device``: the model output bicubically
            resized to ``image.shape[-3:-1]`` (the two axes before the
            last one of the squeezed image).
        """
        if torch.is_tensor(image):
            image = image.cpu().detach()
        if not isinstance(image, np.ndarray):
            image = np.asarray(image)
        image = image.squeeze()
        batch = self.transform(image).to(self.device)
        prediction = self.model(batch)
        # interpolate() needs an explicit channel axis; add it for the
        # resize and drop it again afterwards.
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=image.shape[-3:-1],
            mode="bicubic",
            align_corners=False,
        )[:, 0]
        return prediction

    @torch.no_grad()
    def get_depth(self, img):
        """Return a rescaled depth map for ``img`` as a NumPy array.

        The raw prediction is min-max normalised to the range [3, 10] and
        then mapped through ``30 / d``, so the returned values also lie in
        [3, 10], with the largest raw predictions mapped to the smallest
        outputs.

        Args:
            img: Image convertible with ``np.asarray`` (e.g. a PIL image
                or an ndarray).

        Returns:
            ``numpy.ndarray`` holding the first element of the model
            prediction after rescaling. A constant prediction yields a
            constant map of 10.0 instead of NaNs.
        """
        # forward() converts its input to a CPU NumPy array anyway, so hand
        # it one directly instead of bouncing through self.device.
        image = np.asarray(img, dtype=np.float32)
        raw = self(image)[0]
        lowest = raw.min()
        highest = raw.max()
        # Guard against a zero span (constant prediction), which would
        # otherwise turn the whole map into NaN via 0 / 0.
        span = (highest - lowest).clamp_min(torch.finfo(raw.dtype).eps)
        scaled = (raw - lowest) / span * (10 - 3) + 3
        return (30 / scaled).detach().cpu().numpy()
| |
|
| |
|
if __name__ == "__main__":
    # Manual smoke test: estimate depth for a local sample image and show
    # the resulting map in a matplotlib window.
    import matplotlib.pyplot as plt

    estimator = MidasDepth()
    depth_map = estimator.get_depth(Image.open("horse.jpg"))
    plt.imshow(depth_map)
    plt.show()
| |
|