| import shutil |
| import traceback |
| from io import BytesIO |
| from urllib.parse import urlparse |
|
|
| import cv2 |
| import numpy as np |
| import pydicom |
| import requests |
| import torch |
| import torch.nn.functional as F |
| from PIL import Image |
| from transformers import BitImageProcessor, BlipImageProcessor |
|
|
|
|
@torch.no_grad()
def model_inference(image, text, model, image_processor, tokenizer):
    """
    Score a single text prompt against a single image.

    Args:
        image (str or PIL.Image.Image): Image path (.dcm/.png/.jpg/.jpeg) or an
            already-loaded PIL image; resolved via ``load_image``.
        text (str): Text prompt to score against the image.
        model: Project model exposing ``device`` and ``compute_logits``.
        image_processor: HuggingFace image processor used for preprocessing.
        tokenizer: HuggingFace tokenizer for the text prompt.

    Returns:
        tuple: ``(similarity_prob, similarity_map)`` where the probability is
        the sigmoid of the model logits and the map is a per-pixel sigmoid of
        patch similarity scores upsampled to the original image resolution.
    """
    pil_image = load_image(image)

    # PIL reports (width, height); the interpolation helper expects (height, width).
    width, height = pil_image.size
    original_size = (height, width)

    processed = image_processor(pil_image)
    pixel_values = torch.FloatTensor(
        np.array(processed["pixel_values"])
    ).to(model.device)

    tokens = tokenizer(
        text,
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(model.device)

    outputs = model.compute_logits(pixel_values, [tokens])
    similarity_prob = outputs["logits"].sigmoid()

    # Flatten the per-patch scores, then upsample them to the original image size.
    patch_scores = outputs["similarity_scores"].view(-1)
    patch_scores = interpolate_similarity_scores(
        patch_scores, original_size, image_processor
    )
    similarity_map = patch_scores.sigmoid()[0]

    return similarity_prob, similarity_map
|
|
|
|
@torch.no_grad()
def model_inference_multiple_text(image, text_list, model, image_processor, tokenizer):
    """
    Score several text prompts against one image.

    Args:
        image (str or PIL.Image.Image): Image path or already-loaded PIL image.
        text_list (list[str]): Prompts to score individually.
        model: Project model passed through to ``model_inference``.
        image_processor: HuggingFace image processor.
        tokenizer: HuggingFace tokenizer.

    Returns:
        tuple: ``(probs, similarity_maps)`` — the per-prompt results of
        ``model_inference`` stacked along a new leading dimension.
    """
    results = [
        model_inference(image, prompt, model, image_processor, tokenizer)
        for prompt in text_list
    ]
    stacked_probs = torch.stack([prob for prob, _ in results])
    stacked_maps = torch.stack([sim_map for _, sim_map in results])
    return stacked_probs, stacked_maps
|
|
|
|
def interpolate_similarity_scores(similarity_scores, origin_size, image_processor):
    """
    Upsample flat patch-level similarity scores to the original image size.

    The upsampling must mirror how the image processor resized the image:
    BLIP squashes directly to the target resolution, while BiT resizes the
    shortest side and center-crops, so the crop has to be "undone" by padding.

    Args:
        similarity_scores (torch.Tensor): Flat tensor whose last dimension is a
            perfect square (``patch_size ** 2``) of per-patch scores.
        origin_size (tuple[int, int]): Original image size as ``(height, width)``.
        image_processor: The processor used for preprocessing; must be a
            ``BlipImageProcessor`` or ``BitImageProcessor``.

    Returns:
        torch.Tensor: Scores of shape ``(1, height, width)``. For the BiT path,
        regions outside the center crop are filled with ``-999`` (≈0 after sigmoid).

    Raises:
        ValueError: If the image processor type is not supported.
    """
    height, width = origin_size
    patch_size = int(similarity_scores.shape[-1] ** 0.5)
    scores = similarity_scores.view(1, 1, patch_size, patch_size)

    if isinstance(image_processor, BlipImageProcessor):
        # BLIP resizes the image directly to the target size (no cropping),
        # so a plain bilinear upsample back to (height, width) suffices.
        interpolated_scores = F.interpolate(
            scores,
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )
        interpolated_scores = interpolated_scores.squeeze(1)

    elif isinstance(image_processor, BitImageProcessor):
        # BiT resizes the shortest side then center-crops to a square, so the
        # scores only cover the central ``shortest x shortest`` region.
        shortest = min(height, width)

        interpolated_scores = F.interpolate(
            scores,
            size=(shortest, shortest),
            mode="bilinear",
            align_corners=False,
        )

        cropped_left = (width - shortest) // 2
        cropped_top = (height - shortest) // 2

        # BUGFIX: allocate the padded map on the same device/dtype as the
        # scores; a plain ``torch.ones`` lands on CPU and the in-place
        # assignment fails when the model runs on GPU. -999 sigmoids to ~0.
        original_size_map = torch.full(
            (height, width),
            -999.0,
            dtype=interpolated_scores.dtype,
            device=interpolated_scores.device,
        )
        original_size_map[
            cropped_top : cropped_top + shortest, cropped_left : cropped_left + shortest
        ] = interpolated_scores.view(shortest, shortest)

        interpolated_scores = original_size_map.unsqueeze(0)

    else:
        # BUGFIX: previously fell through with ``interpolated_scores`` unbound,
        # raising an opaque UnboundLocalError; fail with a clear message instead.
        raise ValueError(
            f"Unsupported image processor type: {type(image_processor).__name__}"
        )

    return interpolated_scores
|
|
|
|
| |
| def dicom_to_pil_image(input_file_path, save_dir=None): |
| """ |
| Extract the image from a DICOM file and return it as a PIL.Image object. |
| Args: |
| input_file_path (str): Path to the input DICOM file. |
| Returns: |
| PIL.Image.Image: Processed image. |
| """ |
| try: |
| |
| dcm_file = pydicom.dcmread(input_file_path) |
| raw_image = dcm_file.pixel_array |
|
|
| assert len(raw_image.shape) == 2, "Expecting single channel (grayscale) image." |
|
|
| |
| raw_image = raw_image - raw_image.min() |
| normalized_image = raw_image / raw_image.max() |
| rescaled_image = (normalized_image * 255).astype(np.uint8) |
|
|
| |
| if dcm_file.PhotometricInterpretation == "MONOCHROME1": |
| rescaled_image = cv2.bitwise_not(rescaled_image) |
|
|
| |
| final_image = cv2.equalizeHist(rescaled_image) |
|
|
| |
| image = Image.fromarray(final_image) |
|
|
| if save_dir is not None: |
| shutil.copy2(input_file_path, save_dir) |
|
|
| return image |
| except Exception: |
| print(traceback.format_exc()) |
|
|
|
|
def load_image(image):
    """
    Load an image from a file path or a PIL.Image object.

    Args:
        image (str or PIL.Image.Image): Path to the image file (.dcm, .png,
            .jpg, .jpeg) or a PIL.Image object, which is returned unchanged.

    Returns:
        PIL.Image.Image: Processed image.

    Raises:
        ValueError: If the path has an unsupported extension or the argument
            is neither a string nor a PIL.Image.
    """
    if isinstance(image, str):
        lowered = image.lower()
        if lowered.endswith(".dcm"):
            # DICOM files go through the dedicated extraction pipeline.
            return dicom_to_pil_image(image)
        if lowered.endswith((".png", ".jpg", ".jpeg")):
            return Image.open(image)
        raise ValueError(f"Invalid image type: {image}")

    if isinstance(image, Image.Image):
        return image

    raise ValueError(f"Invalid image type: {type(image)}")
|
|