# CodeSense / codesense/embedder.py
# Uploaded by Yooshiii — commit f8a39f0 ("Upload 36 files", verified).
import torch
import os
from transformers import AutoTokenizer, T5EncoderModel
class CodeT5Embedder:
    """Turn source-code strings into fixed-size vectors with a CodeT5 encoder.

    The tokenizer and encoder weights are loaded once in ``__init__``. Each
    call to :meth:`embed` / :meth:`get_embedding` tokenizes one string, runs a
    single forward pass, and averages the encoder's last hidden state over the
    non-padding tokens.
    """

    def __init__(self, model_name="Salesforce/codet5-base", *, max_length=512, device=None):
        """Load the tokenizer and encoder, then move the encoder to a device.

        Args:
            model_name: Model id or local path passed to ``from_pretrained``.
            max_length: Maximum number of tokens per input; longer code is
                truncated. Defaults to 512 (the previously hard-coded value).
            device: Optional ``torch.device`` or device string. When ``None``,
                CUDA is used if available, otherwise CPU.
        """
        print(f"⏳ Initializing CodeT5 Engine ({model_name})...")
        self.max_length = max_length
        # use_fast=False is the specific fix for the 'List' error on Windows.
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
        except Exception as e:
            # Deliberate best-effort fallback: retry with the library's default
            # tokenizer choice rather than failing outright.
            print(f"⚠️ Primary loader failed, attempting fast-mode fallback: {e}")
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        print("⏳ Loading CodeT5 Model weights (this may take a moment)...")
        self.model = T5EncoderModel.from_pretrained(model_name)

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model.to(self.device)
        # This class only runs inference; put the module in eval mode so
        # training-only behaviour such as dropout is disabled.
        self.model.eval()

        device_name = str(self.device).upper()
        print(f"✅ CodeT5 Engine is Live on {device_name}")

    def embed(self, code: str):
        """Standard method name used by similarity.py; delegates to ``get_embedding``."""
        return self.get_embedding(code)

    def get_embedding(self, code: str):
        """Original method name for compatibility.

        Return the pooled CodeT5 encoding of *code* as a flattened NumPy array.

        Args:
            code: Source text to embed. Non-string or empty input is replaced
                by a single space before tokenization.

        Returns:
            A 1-D NumPy array: the attention-mask-weighted mean of the
            encoder's last hidden state, moved to CPU.
        """
        # Check the type before truthiness: ``not code`` on arbitrary objects
        # (e.g. a multi-element NumPy array) can itself raise.
        if not isinstance(code, str) or not code:
            code = " "

        inputs = self.tokenizer(
            code,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_length,
            padding=True,
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)
            hidden = outputs.last_hidden_state
            mask = inputs.get("attention_mask")
            if mask is None:
                # No mask available: plain average over the sequence dimension.
                pooled = hidden.mean(dim=1)
            else:
                # Average over real tokens only so padding never dilutes the vector.
                weights = mask.unsqueeze(-1).to(hidden.dtype)
                pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1e-9)

        return pooled.cpu().numpy().flatten()