# Source: starry/backend/python-services/services/semantic_service.py
# Commit 2b7aae2 by k-l-lambda: "feat: add Python ML services (CPU mode) with model download"
# Original file size: 11.2 kB
"""
Semantic prediction service.
Detects and classifies musical symbols (notes, rests, clefs, etc.).
Supports both single-model and multi-model cluster directories.
"""
import os
import re
import math
import numpy as np
import torch
import cv2
import yaml
import logging
from predictors.torchscript_predictor import TorchScriptPredictor, resolve_model_path
from common.image_utils import (
array_from_image_stream, slice_feature, splice_output_tensor,
MARGIN_DIVIDER
)
from common.transform import Composer
# Passed by ScoreSemantic as `vertical_units` to the detect_* functions, which
# divide the heatmap height by it to obtain the size of one unit in pixels.
VERTICAL_UNITS = 24.0
# detect_points keeps a blob only if its minimum enclosing circle has a radius
# (in pixels) strictly below this value.
POINT_RADIUS_MAX = 8
def detect_points(heatmap, vertical_units=24):
    """Detect point features (notes, symbols) in a heatmap.

    The heatmap is optionally Gaussian-blurred (kernel grows with its height),
    binarized with an adaptive mean threshold, and every external contour is
    turned into a candidate point at the centre of its minimum enclosing
    circle. Candidates whose radius is not below ``POINT_RADIUS_MAX`` pixels
    are discarded.

    Args:
        heatmap: 2-D single-class heatmap; must be 8-bit, as required by
            ``cv2.adaptiveThreshold``.
        vertical_units: number of units the heatmap height is divided into;
            the returned ``x``/``y`` are expressed in these units.

    Returns:
        A list of dicts with keys:
            ``mark``: pixel-space ``(x, y, radius)`` of the enclosing circle.
            ``x``: horizontal position in units.
            ``y``: vertical position in units, relative to the heatmap's
                vertical centre.
            ``confidence``: sum of ``heatmap / 255`` over the circle's
                bounding square (clamped to the image).
    """
    img_height, img_width = heatmap.shape[0], heatmap.shape[1]
    unit = img_height / vertical_units
    y0 = img_height / 2.0

    # Kernel is 1 (no blur) below 128 px of height, 3 for 128-255 px, and so on.
    blur_kernel = (img_height // 128) * 2 + 1
    if blur_kernel > 1:
        heatmap_blur = cv2.GaussianBlur(heatmap, (blur_kernel, blur_kernel), 0)
    else:
        heatmap_blur = heatmap
    thresh = cv2.adaptiveThreshold(
        heatmap_blur, 255,
        cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 3, 0,
    )
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    points = []
    for contour in contours:
        (x, y), radius = cv2.minEnclosingCircle(contour)
        # Reject oversized blobs before paying for the confidence sum.
        if radius >= POINT_RADIUS_MAX:
            continue

        # Same window a per-pixel loop over range(floor(c - r), ceil(c + r))
        # would visit, clamped to the image bounds.
        x_lo = max(math.floor(x - radius), 0)
        x_hi = min(math.ceil(x + radius), img_width)
        y_lo = max(math.floor(y - radius), 0)
        y_hi = min(math.ceil(y + radius), img_height)
        confidence = heatmap[y_lo:y_hi, x_lo:x_hi].sum() / 255.

        points.append({
            'mark': (x, y, radius),
            'x': x / unit,
            'y': (y - y0) / unit,
            'confidence': float(confidence),
        })
    return points
def detect_vlines(heatmap, vertical_units=24):
    """Detect vertical line features (barlines, stems) in a heatmap.

    The heatmap is optionally Gaussian-blurred (kernel grows with its height),
    binarized with an adaptive mean threshold, and the bounding rectangle of
    every external contour is reported as one vertical line.

    Args:
        heatmap: 2-D single-class heatmap; must be 8-bit, as required by
            ``cv2.adaptiveThreshold``.
        vertical_units: number of units the heatmap height is divided into;
            positions are expressed in these units.

    Returns:
        A list of dicts with keys:
            ``x``: horizontal centre of the line in units.
            ``y``: top of the line in units (same as ``extension['y1']``).
            ``extension``: ``{'y1': top, 'y2': bottom}`` in units, relative
                to the heatmap's vertical centre.
            ``confidence``: sum of ``heatmap / 255`` over the bounding rect,
                divided by ``max(height, 2.5) * 0.8``.
            ``mark``: pixel-space ``(centre_x, top, bottom)``.
    """
    unit = heatmap.shape[0] / vertical_units
    y0 = heatmap.shape[0] / 2.0

    # Kernel is 1 (no blur) below 128 px of height, 3 for 128-255 px, and so on.
    blur_kernel = (heatmap.shape[0] // 128) * 2 + 1
    if blur_kernel > 1:
        heatmap_blur = cv2.GaussianBlur(heatmap, (blur_kernel, blur_kernel), 0)
    else:
        heatmap_blur = heatmap
    thresh = cv2.adaptiveThreshold(
        heatmap_blur, 255,
        cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 3, 0,
    )
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    lines = []
    for contour in contours:
        left, top, width, height = cv2.boundingRect(contour)
        x = (left + width / 2) / unit
        y1 = (top - y0) / unit
        y2 = (top + height - y0) / unit

        region_sum = heatmap[top:top + height, left:left + width].sum() / 255.
        # Normalize by the line length, floored at 2.5 px, scaled by 0.8.
        confidence = region_sum / (max(height, 2.5) * 0.8)

        lines.append({
            'x': x,
            'y': y1,
            'extension': {'y1': y1, 'y2': y2},
            'confidence': float(confidence),
            'mark': (left + width / 2, top, top + height),
        })
    return lines
def detect_rectangles(heatmap, vertical_units=24):
    """Detect axis-aligned rectangular features (e.g. text boxes) in a heatmap.

    The heatmap is binarized at a fixed threshold of 92 and the bounding
    rectangle of every external contour becomes a candidate. Rectangles
    covering less than 2 square units are discarded.

    Args:
        heatmap: 2-D single-class heatmap (8-bit, as fed by ScoreSemantic).
        vertical_units: number of units the heatmap height is divided into;
            positions and sizes are expressed in these units.

    Returns:
        A list of dicts with keys:
            ``x``: horizontal centre in units.
            ``y``: vertical centre in units, relative to the heatmap's
                vertical centre.
            ``extension``: ``{'width', 'height'}`` in units.
            ``confidence``: mean of ``heatmap / 255`` inside the rectangle.
            ``mark``: pixel-space ``(left, top, width, height)``.
    """
    unit = heatmap.shape[0] / vertical_units
    y0 = heatmap.shape[0] / 2.0

    _, thresh = cv2.threshold(heatmap, 92, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    rects = []
    for contour in contours:
        left, top, width, height = cv2.boundingRect(contour)
        # Area filter in square units.
        if width * height / unit / unit < 2:
            continue

        # Bounding rects of contours are at least 1x1, so the slice is non-empty.
        confidence = heatmap[top:top + height, left:left + width].mean() / 255.

        rects.append({
            'x': (left + width / 2) / unit,
            'y': (top + height / 2 - y0) / unit,
            'extension': {'width': width / unit, 'height': height / unit},
            'confidence': float(confidence),
            'mark': (left, top, width, height),
        })
    return rects
def detect_boxes(heatmap, vertical_units=24):
    """Detect rotated box features in a heatmap.

    The heatmap is binarized at a fixed threshold of 92; each external contour
    is fitted with a minimum-area rotated rectangle. Boxes whose shorter side
    is under 8 units are dropped.

    Unlike the other detectors, positions and sizes are returned in pixels;
    ``vertical_units`` is only used for the size filter.

    Returns:
        A list of dicts with keys ``x``/``y`` (box centre, pixels),
        ``extension`` (``width``, ``height`` in pixels and ``theta`` angle),
        ``confidence`` (square root of the box area) and ``mark`` (the raw
        ``cv2.minAreaRect`` result).
    """
    unit = heatmap.shape[0] / vertical_units
    _, binary = cv2.threshold(heatmap, 92, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    boxes = []
    for contour in contours:
        rotated = cv2.minAreaRect(contour)
        (center_x, center_y), (box_w, box_h), angle = rotated
        if min(box_w, box_h) / unit < 8:
            continue
        boxes.append({
            'x': center_x,
            'y': center_y,
            'extension': {'width': box_w, 'height': box_h, 'theta': angle},
            'confidence': math.sqrt(box_w * box_h),
            'mark': rotated,
        })
    return boxes
class ScoreSemantic:
    """Score semantic analysis results.

    Turns one heatmap per semantic label into a flat list of detected points,
    exposed by :meth:`json` as a ``SemanticGraph`` dict.
    """

    @staticmethod
    def _detector_for(semantic):
        """Return ``(detector, has_extension)`` chosen by the label's prefix.

        ``vline_``/``rect_``/``box_`` labels use the matching detector and carry
        an ``extension``; every other label is treated as a point feature.
        """
        if semantic.startswith('vline_'):
            return detect_vlines, True
        if semantic.startswith('rect_'):
            return detect_rectangles, True
        if semantic.startswith('box_'):
            return detect_boxes, True
        return detect_points, False

    def __init__(self, heatmaps, labels, confidence_table=None):
        """Run detection on every heatmap.

        Args:
            heatmaps: sequence of per-label heatmaps, indexed in label order.
            labels: semantic label for each heatmap (same length as heatmaps).
            confidence_table: optional list aligned with ``labels`` of dicts
                with ``semantic`` and ``mean_confidence``; each detection's
                confidence is divided by its label's mean confidence
                (floored at 1e-4).
        """
        self.data = {
            '__prototype': 'SemanticGraph',
            'points': [],
            'staffY': None,
        }
        assert len(labels) == len(heatmaps), \
            f'classes - heatmaps count mismatch: {len(labels)} - {len(heatmaps)}'

        points = self.data['points']
        for i, semantic in enumerate(labels):
            mean_confidence = 1
            if confidence_table is not None:
                item = confidence_table[i]
                assert item['semantic'] == semantic
                # Floor keeps the division below finite for zero/tiny means.
                mean_confidence = max(item['mean_confidence'], 1e-4)

            detector, has_extension = self._detector_for(semantic)
            for feature in detector(heatmaps[i], vertical_units=VERTICAL_UNITS):
                # Key order matches the serialized output: semantic, x, y,
                # [extension], confidence.
                point = {
                    'semantic': semantic,
                    'x': feature['x'],
                    'y': feature['y'],
                }
                if has_extension:
                    point['extension'] = feature['extension']
                point['confidence'] = feature['confidence'] / mean_confidence
                points.append(point)

    def json(self):
        """Return the SemanticGraph dict (the internal object, not a copy)."""
        return self.data
def _is_cluster_dir(model_path):
    """Return True if *model_path* is a semantic cluster directory.

    A cluster directory contains a ``.state.yaml`` file whose top-level
    mapping has a ``subs`` key. A plain file, a directory without that file,
    an empty state file, or a state file whose top level is not a mapping are
    all reported as not a cluster.
    """
    if not os.path.isdir(model_path):
        return False
    state_file = os.path.join(model_path, '.state.yaml')
    if not os.path.isfile(state_file):
        return False
    with open(state_file, 'r', encoding='utf-8') as f:
        state = yaml.safe_load(f)
    # safe_load yields None for an empty document; `'subs' in None` would raise.
    return isinstance(state, dict) and 'subs' in state
class SemanticService:
    """Semantic prediction service.

    Handles both single TorchScript models and multi-model cluster directories.
    A cluster directory has a .state.yaml with 'subs' listing sub-model
    directories. In cluster mode every sub-model runs on the same batch, their
    outputs are concatenated along dim 1 in 'subs' order, and ``self.labels``
    is the concatenation of the sub-models' labels in that same order.
    """

    DEFAULT_TRANS = ['Mono', 'HWC2CHW']
    DEFAULT_SLICING_WIDTH = 512

    def __init__(self, model_path, device='cuda', trans=None, slicing_width=None,
                 labels=None, confidence_table=None, **kwargs):
        """Load a single model or a model cluster from *model_path*.

        Args:
            model_path: TorchScript model path (resolved via
                ``resolve_model_path``) or a cluster directory.
            device: torch device for the model(s) and input batches.
            trans: transform names for ``Composer``; overrides config/defaults.
            slicing_width: width passed to ``slice_feature``; overrides
                config/defaults.
            labels: labels for single-model mode (clusters read them from each
                sub-model's .state.yaml instead).
            confidence_table: per-label list of ``{'semantic', 'mean_confidence'}``
                for single-model mode (clusters build it from their config).
            **kwargs: accepted and ignored.

        Raises:
            ValueError: if a cluster directory lists no sub-models.
        """
        self.device = device
        if _is_cluster_dir(model_path):
            self._init_cluster(model_path, device, trans, slicing_width)
        else:
            self._init_single(model_path, device, trans, slicing_width, labels, confidence_table)

    @staticmethod
    def _load_state(state_file):
        """Read a .state.yaml file; an empty document yields ``{}``."""
        with open(state_file, 'r', encoding='utf-8') as f:
            return yaml.safe_load(f) or {}

    def _init_single(self, model_path, device, trans, slicing_width, labels, confidence_table):
        """Initialize with a single TorchScript model."""
        resolved = resolve_model_path(model_path)
        self.model = torch.jit.load(resolved, map_location=device)
        self.model.eval()
        self.sub_models = None
        self.composer = Composer(trans or self.DEFAULT_TRANS)
        self.slicing_width = slicing_width or self.DEFAULT_SLICING_WIDTH
        self.labels = labels or []
        self.confidence_table = confidence_table
        logging.info('SemanticService: single model loaded: %s', resolved)

    def _init_cluster(self, model_path, device, trans, slicing_width):
        """Initialize with a multi-model cluster directory.

        Raises:
            ValueError: if the cluster's 'subs' list is empty or null.
        """
        cluster_state = self._load_state(os.path.join(model_path, '.state.yaml'))

        # Precedence for each setting: explicit argument > cluster config > default.
        # `or {}` also covers keys that are present but set to null.
        predictor_config = cluster_state.get('predictor') or {}
        self.composer = Composer(
            trans or predictor_config.get('trans') or self.DEFAULT_TRANS
        )
        self.slicing_width = (
            slicing_width
            or predictor_config.get('slicing_width')
            or self.DEFAULT_SLICING_WIDTH
        )
        # label -> mean confidence; labels missing from it default to 1.0 below.
        ct_dict = predictor_config.get('confidence_table') or {}

        subs = cluster_state.get('subs') or []
        if not subs:
            # Fail at load time; otherwise inference would die in torch.cat([]).
            raise ValueError(f'semantic cluster has no sub-models: {model_path}')

        self.sub_models = []
        self.labels = []
        for sub_name in subs:
            sub_dir = os.path.join(model_path, sub_name)
            sub_state = self._load_state(os.path.join(sub_dir, '.state.yaml'))
            sub_args = (sub_state.get('data') or {}).get('args') or {}
            sub_labels = sub_args.get('labels') or []

            sub_model_file = resolve_model_path(sub_dir)
            model = torch.jit.load(sub_model_file, map_location=device)
            model.eval()
            self.sub_models.append(model)
            self.labels.extend(sub_labels)
            logging.info(' sub-model %s: %d labels, file=%s',
                         sub_name, len(sub_labels), os.path.basename(sub_model_file))

        # One entry per label, in label order (ScoreSemantic indexes it by position).
        self.confidence_table = None
        if ct_dict:
            self.confidence_table = [
                {'semantic': label, 'mean_confidence': ct_dict.get(label, 1.0)}
                for label in self.labels
            ]

        self.model = None  # not used for cluster
        logging.info('SemanticService: cluster loaded with %d sub-models, %d total labels',
                     len(self.sub_models), len(self.labels))

    def run_inference(self, batch):
        """Run the model(s) on *batch* under ``torch.no_grad()``.

        Single-model mode returns the model output unchanged (it may be a
        tuple). Cluster mode takes the second element of any 2-tuple output
        from each sub-model and concatenates the results along dim 1.
        """
        with torch.no_grad():
            if self.sub_models is None:
                return self.model(batch)
            outputs = []
            for model in self.sub_models:
                output = model(batch)
                if isinstance(output, tuple):
                    _, output = output
                outputs.append(output)
            return torch.cat(outputs, dim=1)

    def predict(self, streams, **kwargs):
        """
        Predict semantic symbols from image streams.

        streams: iterable of image byte buffers
        yields: one result per stream -- the SemanticGraph dict produced by
            ScoreSemantic.json(), or {'error': 'Invalid image'} when
            array_from_image_stream returns None for the stream.
        """
        for stream in streams:
            image = array_from_image_stream(stream)
            if image is None:
                yield {'error': 'Invalid image'}
                continue

            # Slice the image and stack the pieces into one uint8 batch.
            pieces = np.array(list(slice_feature(
                image,
                width=self.slicing_width,
                overlapping=2 / MARGIN_DIVIDER,
                padding=True,
            )), dtype=np.uint8)

            # NOTE(review): the ones array appears to be a dummy second input
            # whose transformed result is discarded -- confirm against Composer.
            staves, _ = self.composer(pieces, np.ones((1, 4, 4, 2)))
            batch = torch.from_numpy(staves).to(self.device)

            output = self.run_inference(batch)
            # A single model may return a 2-tuple; keep its second element.
            if isinstance(output, tuple):
                _, output = output

            semantic = splice_output_tensor(output)
            # Clip before the uint8 cast so out-of-range values saturate at
            # 0/255 instead of wrapping around.
            heatmaps = np.uint8(np.clip(semantic * 255, 0, 255))
            ss = ScoreSemantic(
                heatmaps,
                self.labels,
                confidence_table=self.confidence_table,
            )
            yield ss.json()