| import torch
|
| from transformers import AutoModel, AutoTokenizer
|
|
|
class ModelHandler:
    """Load a Hugging Face model once and run single-request inference.

    Instances start empty; ``initialize`` must be called before the
    ``preprocess`` -> ``inference`` -> ``postprocess`` pipeline is used.
    """

    def __init__(self):
        # Populated by ``initialize``; ``None`` means "not loaded yet".
        self.model = None
        self.tokenizer = None

    def initialize(self, model_path):
        """Load the tokenizer and model from ``model_path``.

        Args:
            model_path: Directory or hub identifier accepted by
                ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModel.from_pretrained(model_path)
        # from_pretrained normally returns the model in eval mode already;
        # calling it explicitly makes the serving requirement (no dropout) visible.
        self.model.eval()

    def preprocess(self, data):
        """Tokenize the ``"text"`` field of a single request.

        Args:
            data: Mapping holding the request payload. A missing or ``None``
                ``"text"`` is treated as the empty string; ``bytes`` /
                ``bytearray`` values are decoded as UTF-8.

        Returns:
            The tokenizer's encoding, with PyTorch tensors.
        """
        text = data.get("text")
        if text is None:
            text = ""
        elif isinstance(text, (bytes, bytearray)):
            # Raw payloads may arrive undecoded; tokenizers only accept str.
            text = text.decode("utf-8")
        # Truncate so over-long inputs do not exceed the model's maximum length.
        return self.tokenizer(text, return_tensors="pt", truncation=True)

    def inference(self, inputs):
        """Run a forward pass on the tokenized ``inputs``.

        Gradient tracking is disabled: serving never calls ``backward`` and
        building the autograd graph only costs memory and time.
        """
        with torch.no_grad():
            return self.model(**inputs)

    def postprocess(self, outputs):
        """Convert the model output into a JSON-serialisable dict.

        Models with a task head expose ``logits``; the bare ``AutoModel``
        backbone loaded by ``initialize`` returns ``last_hidden_state``
        instead, so that is used when no ``logits`` attribute exists.

        Returns:
            ``{"output": nested_list}`` built from the selected tensor.
        """
        scores = getattr(outputs, "logits", None)
        if scores is None:
            scores = outputs.last_hidden_state
        return {"output": scores.tolist()}
|
|
|
# Module-level singleton shared by every call to ``handle``; the model itself
# is loaded lazily on the first call rather than at import time.
_handler = ModelHandler()


def handle(data, context):
    """Serve a request with the shared ``ModelHandler``.

    On the first call the model is loaded from
    ``context.system_properties["model_dir"]``. Initialization happens before
    the input check, so a call with ``data=None`` still loads the model
    (presumably the serving framework does this at startup — confirm).

    Args:
        data: Sequence of request payloads; only ``data[0]`` is processed.
        context: Object whose ``system_properties`` mapping supplies
            ``"model_dir"``.

    Returns:
        ``{"error": "No input data"}`` when ``data`` is ``None`` or empty,
        otherwise the dict produced by ``ModelHandler.postprocess``.
    """
    # Compare with None instead of relying on truthiness: some torch modules
    # define ``__len__``, so ``not model`` could misfire on a loaded model.
    if _handler.model is None:
        model_path = context.system_properties.get("model_dir")
        _handler.initialize(model_path)

    # Reject an empty batch here instead of failing on ``data[0]`` with IndexError.
    if not data:
        return {"error": "No input data"}

    # NOTE(review): requests after data[0] are ignored and a dict (not a list)
    # is returned. If the server batches requests (TorchServe, for example,
    # expects one result per input), this should map over the whole batch —
    # confirm against the serving framework.
    inputs = _handler.preprocess(data[0])
    outputs = _handler.inference(inputs)
    return _handler.postprocess(outputs)
|
|
|