Spaces:
Sleeping
Sleeping
| import torch, torchvision, os | |
| from PIL import Image | |
# %% image loading
def hfImageToTensor(image, width: int = 1024, height: int = 512) -> torch.Tensor:
    """
    Convert an input image from Hugging Face / the UI into a resized float tensor.

    PIL images are first converted to RGB. Numpy inputs with one (grayscale) or
    four (RGBA) channels are brought to three channels. Together these give the
    (3, height, width) layout that the rest of the pipeline expects. Inputs that
    are already a torch.Tensor skip conversion and are only resized.

    Args:
        image: Input image. Accepted types are a PIL.Image, a numpy array of
            shape (H, W) or (H, W, C), or a (C, H, W) torch.Tensor.
        width (int): Target width. Defaults to 1024.
        height (int): Target height. Defaults to 512.

    Returns:
        torch.Tensor: Image tensor of shape (3, height, width) for PIL/numpy inputs.
    """
    if isinstance(image, Image.Image):
        # Modes such as "P" (palette), "L" or "RGBA" would otherwise yield raw
        # palette indices or the wrong number of channels from to_tensor.
        image = image.convert("RGB")
    if not isinstance(image, torch.Tensor):
        # For PIL / uint8 HWC input, to_tensor returns a CHW float32 tensor in [0, 1].
        image = torchvision.transforms.functional.to_tensor(image)
        channels = image.shape[0]
        if channels == 4:
            # RGBA numpy array: drop the alpha channel.
            image = image[:3]
        elif channels == 1:
            # Grayscale numpy array: replicate the single channel to RGB.
            image = image.repeat(3, 1, 1)
    return torchvision.transforms.functional.resize(image, [height, width])
# %% preprocessing
def preprocessing(image_tensor: torch.Tensor) -> torch.Tensor:
    """
    Normalize an image tensor channel-wise and prepend a batch axis.

    Each channel is shifted and scaled with the standard ImageNet statistics
    (mean 0.485/0.456/0.406, std 0.229/0.224/0.225) using torchvision's normalize.

    Args:
        image_tensor (torch.Tensor): Float image tensor of shape (3, H, W).

    Returns:
        torch.Tensor: Normalized tensor of shape (1, 3, H, W).
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    normalized = torchvision.transforms.functional.normalize(
        image_tensor, mean=channel_mean, std=channel_std
    )
    # Indexing with None inserts a leading axis, equivalent to unsqueeze(0).
    return normalized[None]
# %% colorize segmentation masks (semantic-segmentation palette)
def print_mask(mask: torch.Tensor, numClasses: int = 19) -> torch.Tensor:
    """
    Colorize a segmentation label map by mapping each class id to a fixed RGB color.

    Despite its name, the function prints nothing; it returns the color image.

    Args:
        mask (torch.Tensor): Label map of shape (H, W) holding class ids. Extra
            leading (batch) dimensions are also accepted.
        numClasses (int, optional): Number of palette entries to use. Ids in
            [0, numClasses) are colored. Defaults to 19.

    Returns:
        torch.Tensor: A uint8 tensor on the same device as ``mask``. Its shape is
        (3, H, W), or (..., 3, H, W) for batched input. Pixels are black when their
        id is not a whole number in [0, numClasses), e.g. the ignore label 255.

    Raises:
        ValueError: If ``numClasses`` is larger than the number of palette colors.
    """
    colors = [
        (128, 64, 128),   # 0: road
        (244, 35, 232),   # 1: sidewalk
        (70, 70, 70),     # 2: building
        (102, 102, 156),  # 3: wall
        (190, 153, 153),  # 4: fence
        (153, 153, 153),  # 5: pole
        (250, 170, 30),   # 6: traffic light
        (220, 220, 0),    # 7: traffic sign
        (107, 142, 35),   # 8: vegetation
        (152, 251, 152),  # 9: terrain
        (70, 130, 180),   # 10: sky
        (220, 20, 60),    # 11: person
        (255, 0, 0),      # 12: rider
        (0, 0, 142),      # 13: car
        (0, 0, 70),       # 14: truck
        (0, 60, 100),     # 15: bus
        (0, 80, 100),     # 16: train
        (0, 0, 230),      # 17: motorcycle
        (119, 11, 32),    # 18: bicycle
    ]
    if numClasses > len(colors):
        raise ValueError(
            f"numClasses={numClasses} exceeds the {len(colors)} colors available"
        )
    palette = torch.tensor(colors, dtype=torch.uint8, device=mask.device)

    # A single palette lookup replaces one full-image comparison per class.
    valid = (mask >= 0) & (mask < numClasses)
    if mask.is_floating_point():
        # Only exact integer ids select a color; fractional values stay black.
        valid &= mask == mask.trunc()

    new_mask = torch.zeros((*mask.shape, 3), dtype=torch.uint8, device=mask.device)
    new_mask[valid] = palette[mask[valid].long()]
    # Move the color axis in front of the spatial axes: (..., H, W, 3) -> (..., 3, H, W).
    return new_mask.movedim(-1, -3)
def legendHandling() -> list[list]:
    """
    Return the legend of the semantic-segmentation classes, ordered by class id.

    Each entry is a 4-element list:
        - class id (int)
        - class name (str)
        - human-readable color description (str)
        - RGB color (tuple[int, int, int])

    A new list is built on every call, so callers may mutate the result freely.
    """
    # NOTE: the RGB values mirror the palette in print_mask; keep the two in sync.
    legend = [
        [0, "road", "dark purple", (128, 64, 128)],
        [1, "sidewalk", "light purple / pink", (244, 35, 232)],
        [2, "building", "dark gray", (70, 70, 70)],
        [3, "wall", "blue + grey", (102, 102, 156)],
        [4, "fence", "beige", (190, 153, 153)],
        [5, "pole", "grey", (153, 153, 153)],
        [6, "traffic light", "orange", (250, 170, 30)],
        [7, "traffic sign", "yellow", (220, 220, 0)],
        [8, "vegetation", "dark green", (107, 142, 35)],
        [9, "terrain", "light green", (152, 251, 152)],
        [10, "sky", "blue", (70, 130, 180)],
        [11, "person", "dark red", (220, 20, 60)],
        [12, "rider", "light red", (255, 0, 0)],
        [13, "car", "blue", (0, 0, 142)],
        [14, "truck", "dark blue", (0, 0, 70)],
        [15, "bus", "dark blue", (0, 60, 100)],
        [16, "train", "blue + green", (0, 80, 100)],
        [17, "motorcycle", "light blue", (0, 0, 230)],
        [18, "bicycle", "velvet", (119, 11, 32)],
    ]
    # Sorting keeps the documented ordering even if entries are edited out of order.
    return sorted(legend, key=lambda entry: entry[0])
# %% postprocessing
def postprocessing(pred: torch.Tensor) -> Image.Image:
    """
    Turn a predicted label map into a color PIL image for display.

    Args:
        pred (torch.Tensor): Predicted class-id map, expected to have shape
            (1, H, W); the leading axis is squeezed away. Values are cast to
            uint8, so ids are assumed to lie in [0, 255].

    Returns:
        PIL.Image.Image: The label map colorized by ``print_mask`` and converted
        with torchvision's ``to_pil_image``.
    """
    # Drop the batch axis and bring the ids to CPU / uint8 before colorizing.
    label_map = pred.squeeze(0).cpu().to(torch.uint8)
    return torchvision.transforms.functional.to_pil_image(print_mask(label_map))
# %% preloaded images
def loadPreloadedImages(*args: str) -> list[Image.Image]:
    """
    Load every PNG/JPEG image found directly inside the given directories.

    Args:
        *args (str): One or more directory paths. Sub-directories are not searched.
            File extensions (.png, .jpg, .jpeg) are matched case-insensitively.

    Returns:
        list[Image.Image]: The images converted to RGB, ordered by their full path
        (directory joined with file name).
    """
    extensions = (".png", ".jpg", ".jpeg")
    paths = sorted(
        os.path.join(imageDir, name)
        for imageDir in args
        for name in os.listdir(imageDir)
        if name.lower().endswith(extensions)
        and os.path.isfile(os.path.join(imageDir, name))
    )
    images = []
    for path in paths:
        # The context manager closes the file handle. convert() returns a separate,
        # fully loaded image, so the result remains usable after the file closes.
        with Image.open(path) as img:
            images.append(img.convert("RGB"))
    return images