# codesense/similarity.py — CodeSense ML v2 prototype-similarity module.
import torch
import torch.nn.functional as F
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
from .embedder import CodeT5Embedder
# -------- Singleton Embedder --------
# One shared CodeT5 embedder, created at import time and reused by both
# the prototype precompute pass below and predict_algorithm().
_embedder = CodeT5Embedder()
# ============================================================
# ML v2 PROTOTYPE STRUCTURE
# Category → Algorithm → [Variants]
# ============================================================
PROTOTYPES = {
"Sorting Algorithm": {
"Bubble Sort": [
# Classic
"""
def bubble_sort(arr):
for i in range(len(arr)):
for j in range(len(arr)-i-1):
if arr[j] > arr[j+1]:
arr[j], arr[j+1] = arr[j+1], arr[j]
""",
# Optimized (swapped flag)
"""
def bubble_sort(arr):
n = len(arr)
for i in range(n):
swapped = False
for j in range(0, n-i-1):
if arr[j] > arr[j+1]:
arr[j], arr[j+1] = arr[j+1], arr[j]
swapped = True
if not swapped:
break
"""
],
"Insertion Sort": [
# Shift-based
"""
def insertion_sort(arr):
for i in range(1, len(arr)):
key = arr[i]
j = i - 1
while j >= 0 and arr[j] > key:
arr[j+1] = arr[j]
j -= 1
arr[j+1] = key
""",
# Swap-based variant
"""
def insertion_sort(arr):
for i in range(1, len(arr)):
j = i
while j > 0 and arr[j] < arr[j-1]:
arr[j], arr[j-1] = arr[j-1], arr[j]
j -= 1
"""
],
"Merge Sort": [
# Slicing-based
"""
def merge_sort(arr):
if len(arr) <= 1:
return arr
mid = len(arr)//2
left = merge_sort(arr[:mid])
right = merge_sort(arr[mid:])
return merge(left, right)
""",
# Index-based (GFG style)
"""
def merge(arr, l, m, r):
n1 = m - l + 1
n2 = r - m
L = [0] * n1
R = [0] * n2
for i in range(n1):
L[i] = arr[l + i]
for j in range(n2):
R[j] = arr[m + 1 + j]
i = j = 0
k = l
while i < n1 and j < n2:
if L[i] <= R[j]:
arr[k] = L[i]
i += 1
else:
arr[k] = R[j]
j += 1
k += 1
def merge_sort(arr, l, r):
if l < r:
m = l + (r - l)//2
merge_sort(arr, l, m)
merge_sort(arr, m+1, r)
merge(arr, l, m, r)
"""
],
"Quick Sort": [
# List-comprehension variant
"""
def quick_sort(arr):
if len(arr) <= 1:
return arr
pivot = arr[0]
left = [x for x in arr[1:] if x <= pivot]
right = [x for x in arr[1:] if x > pivot]
return quick_sort(left) + [pivot] + quick_sort(right)
""",
# Partition-based (GFG style)
"""
def partition(arr, low, high):
pivot = arr[high]
i = low - 1
for j in range(low, high):
if arr[j] <= pivot:
i += 1
arr[i], arr[j] = arr[j], arr[i]
arr[i+1], arr[high] = arr[high], arr[i+1]
return i+1
def quick_sort(arr, low, high):
if low < high:
pi = partition(arr, low, high)
quick_sort(arr, low, pi-1)
quick_sort(arr, pi+1, high)
"""
],
"Heap Sort": [
# heapq-based
"""
import heapq
def heap_sort(arr):
heapq.heapify(arr)
return [heapq.heappop(arr) for _ in range(len(arr))]
""",
# Manual heapify (GFG style)
"""
def heapify(arr, n, i):
largest = i
l = 2*i + 1
r = 2*i + 2
if l < n and arr[l] > arr[largest]:
largest = l
if r < n and arr[r] > arr[largest]:
largest = r
if largest != i:
arr[i], arr[largest] = arr[largest], arr[i]
heapify(arr, n, largest)
def heap_sort(arr):
n = len(arr)
for i in range(n//2 - 1, -1, -1):
heapify(arr, n, i)
for i in range(n-1, 0, -1):
arr[i], arr[0] = arr[0], arr[i]
heapify(arr, i, 0)
"""
]
},
"Dynamic Programming": {
"Memoization": [
"""
memo = {}
def fib(n):
if n in memo:
return memo[n]
if n <= 1:
return n
memo[n] = fib(n-1) + fib(n-2)
return memo[n]
"""
],
"Tabulation": [
"""
def fib(n):
dp = [0]*(n+1)
dp[1] = 1
for i in range(2, n+1):
dp[i] = dp[i-1] + dp[i-2]
return dp[n]
""",
"""
def knapsack(weights, values, capacity):
n = len(weights)
dp = [[0]*(capacity+1) for _ in range(n+1)]
for i in range(1, n+1):
for w in range(capacity+1):
if weights[i-1] <= w:
dp[i][w] = max(values[i-1] + dp[i-1][w-weights[i-1]],
dp[i-1][w])
else:
dp[i][w] = dp[i-1][w]
"""
]
},
"Graph Algorithm": {
"Breadth-First Search": [
"""
from collections import deque
def bfs(graph, start):
visited = set()
queue = deque([start])
while queue:
node = queue.popleft()
for neighbor in graph[node]:
if neighbor not in visited:
visited.add(neighbor)
queue.append(neighbor)
"""
],
"Depth-First Search": [
# Recursive
"""
def dfs(graph, node, visited):
visited.add(node)
for neighbor in graph[node]:
if neighbor not in visited:
dfs(graph, neighbor, visited)
""",
# Iterative
"""
def dfs(graph, start):
visited = set()
stack = [start]
while stack:
node = stack.pop()
if node not in visited:
visited.add(node)
for neighbor in graph[node]:
stack.append(neighbor)
"""
]
},
"Pointer Technique": {
"Two-Pointer Technique": [
"""
def two_sum_sorted(arr, target):
left, right = 0, len(arr)-1
while left < right:
s = arr[left] + arr[right]
if s == target:
return True
elif s < target:
left += 1
else:
right -= 1
"""
],
"Sliding Window": [
"""
def max_subarray(arr, k):
current_sum = 0
left = 0
for right in range(len(arr)):
current_sum += arr[right]
if right-left+1 > k:
current_sum -= arr[left]
left += 1
"""
]
},
"Search Algorithm": {
"Binary Search": [
"""
def binary_search(arr, target):
left, right = 0, len(arr)-1
while left <= right:
mid = (left+right)//2
if arr[mid] == target:
return mid
elif arr[mid] < target:
left = mid+1
else:
right = mid-1
"""
]
},
"Data Structure Based": {
"Heap-Based Algorithm": [
"""
import heapq
def top_k(nums, k):
heap = []
for num in nums:
heapq.heappush(heap, num)
if len(heap) > k:
heapq.heappop(heap)
"""
]
}
}
# ============================================================
# PRECOMPUTE EMBEDDINGS
# ============================================================
# Embed every prototype variant once at import time, mirroring the
# Category → Algorithm → [Variants] layout of PROTOTYPES.
_PROTOTYPE_EMBEDDINGS = {
    category: {
        algo_name: [_embedder.embed(snippet) for snippet in variants]
        for algo_name, variants in algorithms.items()
    }
    for category, algorithms in PROTOTYPES.items()
}
# ============================================================
# ML v2 PREDICTION
# ============================================================
def predict_algorithm(code: str) -> dict:
    """Classify *code* by nearest-prototype cosine similarity.

    Embeds the snippet once, compares it against every precomputed
    prototype variant embedding, and reports the single closest
    algorithm plus the best score observed in each category.

    Args:
        code: Source-code snippet to classify.

    Returns:
        dict with keys:
            "ml_prediction":   name of the closest algorithm
                               (None if no prototypes exist),
            "ml_category":     category of that algorithm,
            "confidence":      best cosine similarity, rounded to 3 dp,
            "category_scores": {category: best similarity seen there}.
    """
    user_embedding = _embedder.embed(code)
    # Hoisted out of the loops: this tensor is loop-invariant, and
    # torch.as_tensor avoids the copy (and the FutureWarning) that
    # torch.tensor() produces when the embedding is already a tensor.
    user_tensor = torch.as_tensor(user_embedding).unsqueeze(0)

    best_algorithm = None
    best_category = None
    best_score = -1.0
    category_scores = {}

    for category, algorithms in _PROTOTYPE_EMBEDDINGS.items():
        category_best = -1.0
        for algo_name, variant_embeddings in algorithms.items():
            for proto_embedding in variant_embeddings:
                similarity = F.cosine_similarity(
                    user_tensor,
                    torch.as_tensor(proto_embedding).unsqueeze(0)
                ).item()
                # Track global best across all categories.
                if similarity > best_score:
                    best_score = similarity
                    best_algorithm = algo_name
                    best_category = category
                # Track best per category.
                if similarity > category_best:
                    category_best = similarity
        category_scores[category] = round(category_best, 3)

    return {
        "ml_prediction": best_algorithm,
        "ml_category": best_category,
        "confidence": round(best_score, 3),
        "category_scores": category_scores
    }