aoiandroid wybxc commited on
Commit
4a5aa99
·
0 Parent(s):

Duplicate from wybxc/DocLayout-YOLO-DocStructBench-onnx

Browse files

Co-authored-by: wybxc <wybxc@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ base_model:
4
+ - juliozhao/DocLayout-YOLO-DocStructBench
5
+ ---
6
+ Converted from [juliozhao/DocLayout-YOLO-DocStructBench](https://huggingface.co/juliozhao/DocLayout-YOLO-DocStructBench)
doclayout_yolo_docstructbench_imgsz1024.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fece9af02f618b603ff7921ccec6861d13e7e1f9830e091dfb7e8ad9311e5b21
3
+ size 75324598
inference.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import onnx
3
+ import onnxruntime as ort
4
+ import cv2
5
+ from huggingface_hub import hf_hub_download
6
+ import numpy as np
7
+
8
+ # Download the model from the Hugging Face Hub
9
+ model = hf_hub_download(
10
+ repo_id="wybxc/DocLayout-YOLO-DocStructBench-onnx",
11
+ filename="doclayout_yolo_docstructbench_imgsz1024.onnx",
12
+ )
13
+ model = onnx.load(model)
14
+ metadata = {prop.key: prop.value for prop in model.metadata_props}
15
+
16
+ names = ast.literal_eval(metadata["names"])
17
+ stride = ast.literal_eval(metadata["stride"])
18
+
19
+ # Load the model with ONNX Runtime
20
+ session = ort.InferenceSession(model.SerializeToString())
21
+
22
+
23
+ def resize_and_pad_image(image, new_shape, stride=32):
24
+ """
25
+ Resize and pad the image to the specified size, ensuring dimensions are multiples of stride.
26
+
27
+ Parameters:
28
+ - image: Input image
29
+ - new_shape: Target size (integer or (height, width) tuple)
30
+ - stride: Padding alignment stride, default 32
31
+
32
+ Returns:
33
+ - Processed image
34
+ """
35
+ if isinstance(new_shape, int):
36
+ new_shape = (new_shape, new_shape)
37
+
38
+ h, w = image.shape[:2]
39
+ new_h, new_w = new_shape
40
+
41
+ # Calculate scaling ratio
42
+ r = min(new_h / h, new_w / w)
43
+ resized_h, resized_w = int(round(h * r)), int(round(w * r))
44
+
45
+ # Resize image
46
+ image = cv2.resize(image, (resized_w, resized_h), interpolation=cv2.INTER_LINEAR)
47
+
48
+ # Calculate padding size and align to stride multiple
49
+ pad_w = (new_w - resized_w) % stride
50
+ pad_h = (new_h - resized_h) % stride
51
+ top, bottom = pad_h // 2, pad_h - pad_h // 2
52
+ left, right = pad_w // 2, pad_w - pad_w // 2
53
+
54
+ # Add padding
55
+ image = cv2.copyMakeBorder(
56
+ image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
57
+ )
58
+
59
+ return image
60
+
61
+
62
+ def scale_boxes(img1_shape, boxes, img0_shape):
63
+ """
64
+ Rescales bounding boxes (in the format of xyxy by default) from the shape of the image they were originally
65
+ specified in (img1_shape) to the shape of a different image (img0_shape).
66
+
67
+ Args:
68
+ img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
69
+ boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2)
70
+ img0_shape (tuple): the shape of the target image, in the format of (height, width).
71
+
72
+ Returns:
73
+ boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
74
+ """
75
+
76
+ # Calculate scaling ratio
77
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])
78
+
79
+ # Calculate padding size
80
+ pad_x = round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1)
81
+ pad_y = round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1)
82
+
83
+ # Remove padding and scale boxes
84
+ boxes[..., :4] = (boxes[..., :4] - [pad_x, pad_y, pad_x, pad_y]) / gain
85
+ return boxes
86
+
87
+
88
+ class YoloResult:
89
+ def __init__(self, boxes, names):
90
+ self.boxes = [YoloBox(data=d) for d in boxes]
91
+ self.boxes = sorted(self.boxes, key=lambda x: x.conf, reverse=True)
92
+ self.names = names
93
+
94
+
95
+ class YoloBox:
96
+ def __init__(self, data):
97
+ self.xyxy = data[:4]
98
+ self.conf = data[-2]
99
+ self.cls = data[-1]
100
+
101
+
102
+ def inference(image):
103
+ """
104
+ Run inference on the input image.
105
+
106
+ Parameters:
107
+ - image: Input image, HWC format and RGB order
108
+
109
+ Returns:
110
+ - YoloResult object containing the predicted boxes and class names
111
+ """
112
+
113
+ # Preprocess image
114
+ orig_h, orig_w = image.shape[:2]
115
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
116
+ pix = resize_and_pad_image(image, new_shape=int(image.shape[0] / stride) * stride)
117
+ pix = np.transpose(pix, (2, 0, 1)) # CHW
118
+ pix = np.expand_dims(pix, axis=0) # BCHW
119
+ pix = pix.astype(np.float32) / 255.0 # Normalize to [0, 1]
120
+ new_h, new_w = pix.shape[2:]
121
+
122
+ # Run inference
123
+ preds = session.run(None, {"images": pix})[0]
124
+
125
+ # Postprocess predictions
126
+ preds = preds[preds[..., 4] > 0.25]
127
+ preds[..., :4] = scale_boxes((new_h, new_w), preds[..., :4], (orig_h, orig_w))
128
+ return YoloResult(boxes=preds, names=names)
129
+
130
+
131
+ if __name__ == "__main__":
132
+ import sys
133
+ import matplotlib
134
+ import matplotlib.pyplot as plt
135
+ import matplotlib.colors as colors
136
+
137
+ image = sys.argv[1]
138
+ image = cv2.imread(image)
139
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
140
+
141
+ layout = inference(image)
142
+
143
+ bitmap = np.ones(image.shape[:2], dtype=np.uint8)
144
+ h, w = bitmap.shape
145
+ vcls = ["abandon", "figure", "table", "isolate_formula", "formula_caption"]
146
+ for i, d in enumerate(layout.boxes):
147
+ x0, y0, x1, y1 = d.xyxy.squeeze()
148
+ x0, y0, x1, y1 = (
149
+ np.clip(int(x0 - 1), 0, w - 1),
150
+ np.clip(int(h - y1 - 1), 0, h - 1),
151
+ np.clip(int(x1 + 1), 0, w - 1),
152
+ np.clip(int(h - y0 + 1), 0, h - 1),
153
+ )
154
+ if layout.names[int(d.cls)] in vcls:
155
+ bitmap[y0:y1, x0:x1] = 0
156
+ else:
157
+ bitmap[y0:y1, x0:x1] = i + 2
158
+ bitmap = bitmap[::-1, :]
159
+
160
+ # map bitmap to color
161
+ colormap = matplotlib.colormaps["Pastel1"]
162
+ norm = colors.Normalize(vmin=bitmap.min(), vmax=bitmap.max())
163
+ colored_bitmap = colormap(norm(bitmap))
164
+ colored_bitmap = (colored_bitmap[:, :, :3] * 255).astype(np.uint8)
165
+
166
+ # overlay bitmap on image
167
+ image_with_bitmap = cv2.multiply(image, colored_bitmap, scale=1 / 255)
168
+
169
+ # show the results
170
+ fig, ax = plt.subplots(1, 3, figsize=(15, 6))
171
+ ax[0].imshow(image)
172
+ ax[1].imshow(bitmap, cmap="Pastel1")
173
+ ax[2].imshow(image_with_bitmap)
174
+ plt.show()