CodeSense / codesense /features.py
Yooshiii's picture
Upload 36 files
f8a39f0 verified
import ast
class FeatureExtractor(ast.NodeVisitor):
    """
    Traverses the AST and extracts structural features
    from Python source code.

    Usage: create an instance, call ``visit(tree)`` and read ``features``.
    The visitor records raw counters and low-level flags only. Several
    aggregate entries (``bfs_pattern``, ``memoization_pattern``,
    ``tabulation_pattern``, ``dp_pattern``, the sorting flags,
    ``sliding_window_pattern``, ``heap_pattern``) and
    ``features["max_loop_depth"]`` keep their defaults here; the module's
    ``extract_features`` function fills them in afterwards.

    Some heuristics read counters accumulated *so far* during the walk
    (e.g. "at least one for-loop has been visited"). Their outcome
    therefore depends on visiting order, not on the enclosing scope.
    """

    # Variable names treated as memo / DP tables by the name heuristics.
    _DP_NAMES = ("memo", "cache", "dp")
    # heapq functions that count as heap operations.
    _HEAP_FUNCS = ("heappush", "heappop", "heapify")

    def __init__(self):
        self.features = {
            "for_loops": 0,
            "while_loops": 0,
            "function_calls": set(),
            "recursion": False,
            "max_loop_depth": 0,
            "recursive_call_count": 0,
            "divide_and_conquer": False,
            "binary_search_pattern": False,
            "pointer_variables": set(),
            "pointer_updates": 0,
            "bfs_pattern": False,
            "queue_variables": set(),
            "queue_operations": 0,
            "queue_pop_front": False,
            "queue_append_detected": False,
            "graph_iteration": False,
            "dfs_pattern": False,
            "uses_stack": False,
            "uses_pop": False,
            "dp_pattern": False,
            "uses_dp_array": False,
            "sorting_pattern": False,
            "bubble_sort_pattern": False,
            "insertion_sort_pattern": False,
            "adjacent_swap_detected": False,
            "insertion_shift_detected": False,
            "memoization_pattern": False,
            "memo_dict_defined": False,
            "memo_lookup_detected": False,
            "memo_store_detected": False,
            "tabulation_pattern": False,
            "dp_self_dependency": False,
            "dp_dimension": 1,
            "merge_sort_pattern": False,
            "quick_sort_pattern": False,
            "sliding_window_pattern": False,
            "window_updates": 0,
            "window_shrinks": 0,
            "heap_imported": False,
            "heap_operations": 0,
            "heap_pattern": False,
        }
        # Kept for compatibility; not read or written elsewhere in this class.
        self.current_function = None
        # Nesting depth of for/while loops at the node being visited.
        self.current_loop_depth = 0
        # Deepest loop nesting seen during the walk.
        self.max_loop_depth = 0
        # Name of the innermost enclosing function (None at module level);
        # used to recognise recursive calls.
        self.current_function_name = None
        # Local names bound to the heapq module. Seeded with "heapq" so that
        # ``heapq.heappush(...)`` counts even without a visible import;
        # ``import heapq as hq`` adds "hq".
        self._heapq_modules = {"heapq"}
        # Local names bound by ``from heapq import heappush [as push]``.
        self._heapq_functions = set()

    def visit_Import(self, node):
        """Flag ``import heapq`` and remember the name it is bound to."""
        for alias in node.names:
            if alias.name == "heapq":
                self.features["heap_imported"] = True
                self._heapq_modules.add(alias.asname or alias.name)
        self.generic_visit(node)

    def visit_ImportFrom(self, node):
        """Flag ``from heapq import ...`` and remember imported heap functions."""
        if node.module == "heapq":
            self.features["heap_imported"] = True
            for alias in node.names:
                if alias.name == "*":
                    self._heapq_functions.update(self._HEAP_FUNCS)
                elif alias.name in self._HEAP_FUNCS:
                    self._heapq_functions.add(alias.asname or alias.name)
        self.generic_visit(node)

    def visit_FunctionDef(self, node):
        """Track the enclosing function name while visiting its body."""
        # Save and restore so nested definitions don't leak their name outward.
        previous_function = self.current_function_name
        self.current_function_name = node.name
        self.generic_visit(node)
        self.current_function_name = previous_function

    # ``async def`` bodies get the same recursion tracking.
    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_For(self, node):
        """Count for-loops, track nesting depth and sliding-window hints."""
        self.features["for_loops"] += 1
        self.current_loop_depth += 1
        self.max_loop_depth = max(self.max_loop_depth, self.current_loop_depth)

        # Detect graph[node] iteration: the iterable is a subscript.
        if isinstance(node.iter, ast.Subscript):
            self.features["graph_iteration"] = True

        self.generic_visit(node)
        self.current_loop_depth -= 1

        # A loop variable named like a window's right edge advances the window.
        if isinstance(node.target, ast.Name):
            var = node.target.id.lower()
            if var in ("right", "r", "end"):
                self.features["window_updates"] += 1

    def visit_While(self, node):
        """Count while-loops and track nesting depth."""
        self.features["while_loops"] += 1
        self.current_loop_depth += 1
        self.max_loop_depth = max(self.max_loop_depth, self.current_loop_depth)
        self.generic_visit(node)
        self.current_loop_depth -= 1

    def visit_Call(self, node):
        """Inspect calls for recursion, queue/stack usage and heap operations."""
        func = node.func

        if isinstance(func, ast.Name):
            function_name = func.id
            self.features["function_calls"].add(function_name)

            # Detect recursion: a call to the function currently being visited.
            if function_name == self.current_function_name:
                self.features["recursion"] = True
                self.features["recursive_call_count"] += 1

                # Recursion + a for-loop seen so far -> DFS-style traversal.
                if self.features["for_loops"] >= 1:
                    self.features["dfs_pattern"] = True

                # Detect divide-and-conquer in the recursive call's arguments.
                for arg in node.args:
                    # Case 1: n // 2 or n / 2
                    if isinstance(arg, ast.BinOp) and isinstance(
                        arg.op, (ast.FloorDiv, ast.Div)
                    ):
                        self.features["divide_and_conquer"] = True
                    # Case 2: slicing like arr[:mid] (also the merge-sort signal)
                    if isinstance(arg, ast.Subscript) and isinstance(
                        arg.slice, ast.Slice
                    ):
                        self.features["divide_and_conquer"] = True
                        self.features["merge_sort_pattern"] = True

            # Bare-name heap calls, e.g. after ``from heapq import heappush``.
            if function_name in self._heapq_functions:
                self.features["heap_operations"] += 1

        if isinstance(func, ast.Attribute):
            method = func.attr
            owner = func.value.id if isinstance(func.value, ast.Name) else None

            # Operations on variables that were created with deque(...).
            if (
                owner in self.features["queue_variables"]
                and method in ("append", "popleft")
            ):
                self.features["queue_operations"] += 1

            if method == "pop":
                # stack.pop() or queue.pop()
                self.features["uses_pop"] = True
                # pop(0): a list used as a FIFO queue (list-based BFS).
                if (
                    node.args
                    and isinstance(node.args[0], ast.Constant)
                    and node.args[0].value == 0
                ):
                    self.features["queue_pop_front"] = True
            elif method == "popleft":
                self.features["queue_pop_front"] = True
            elif method == "append":
                self.features["queue_append_detected"] = True

            # Heap calls through the module or an alias: heapq.heappush / hq.heappush.
            if owner in self._heapq_modules and method in self._HEAP_FUNCS:
                self.features["heap_operations"] += 1

        # Iterative DFS heuristic (reads flags accumulated so far).
        if (
            self.features["uses_stack"]
            and self.features["uses_pop"]
            and self.features["for_loops"] >= 1
        ):
            self.features["dfs_pattern"] = True

        self.generic_visit(node)

    def visit_Assign(self, node):
        """Inspect assignments for pointer, queue, stack, DP and sorting hints."""
        value = node.value
        # ast.Assign always has at least one target; the guard only protects
        # against hand-built nodes.
        first_target = node.targets[0] if node.targets else None

        # -------- Binary Search Pattern Detection --------
        # mid = (lo + hi) // 2
        if (
            isinstance(value, ast.BinOp)
            and isinstance(value.op, ast.FloorDiv)
            and isinstance(value.left, ast.BinOp)
            and isinstance(value.left.op, ast.Add)
        ):
            self.features["binary_search_pattern"] = True

        if isinstance(first_target, ast.Name):
            var = first_target.id
            lowered = var.lower()

            # -------- Two Pointer Detection --------
            # Case 1: left = 0 (any literal).  Case 2: right = len(arr) - 1.
            # ast.Constant covers all literals since Python 3.8; the old
            # ast.Num alias was removed in Python 3.14.
            if isinstance(value, (ast.Constant, ast.BinOp)):
                self.features["pointer_variables"].add(var)

            # -------- BFS Detection: q = deque(...) --------
            if (
                isinstance(value, ast.Call)
                and isinstance(value.func, ast.Name)
                and value.func.id == "deque"
            ):
                self.features["queue_variables"].add(var)

            # Stack initialization: stack = [] / stack = list()
            if isinstance(value, (ast.List, ast.Call)) and lowered == "stack":
                self.features["uses_stack"] = True

            # Memo dictionary initialization: memo = {}
            if isinstance(value, ast.Dict) and lowered in self._DP_NAMES:
                self.features["memo_dict_defined"] = True

            # Quick sort: a variable named "pivot".
            if lowered == "pivot":
                self.features["quick_sort_pattern"] = True

        if isinstance(first_target, ast.Subscript):
            # memo[n] = ... (direct subscript of a memo/cache/dp name)
            if (
                isinstance(first_target.value, ast.Name)
                and first_target.value.id.lower() in self._DP_NAMES
            ):
                self.features["memo_store_detected"] = True

            # Tabulation recurrence, also for 2-D tables (dp[i][j] = dp[i-1][j] ...):
            # walk down nested subscripts to the base name, then look for that
            # same name on the right-hand side.
            base = first_target.value
            while isinstance(base, ast.Subscript):
                base = base.value
            if isinstance(base, ast.Name):
                table = base.id.lower()
                if table in self._DP_NAMES and any(
                    isinstance(child, ast.Name) and child.id.lower() == table
                    for child in ast.walk(value)
                ):
                    self.features["dp_self_dependency"] = True

            # Insertion sort shift: a[j + 1] = a[j]
            if isinstance(value, ast.Subscript):
                self.features["insertion_shift_detected"] = True

        # 2-D DP tables: any list comprehension or list-of-lists literal.
        # NOTE(review): this fires for any such assignment, not only dp tables.
        if isinstance(value, ast.ListComp) or (
            isinstance(value, ast.List)
            and any(isinstance(el, ast.List) for el in value.elts)
        ):
            self.features["dp_dimension"] = 2

        # -------- Bubble Sort Adjacent Swap Detection --------
        # a[i], a[j] = a[j], a[i]
        if (
            isinstance(first_target, ast.Tuple)
            and isinstance(value, ast.Tuple)
            and len(first_target.elts) == 2
            and len(value.elts) == 2
            and all(
                isinstance(el, ast.Subscript)
                for el in first_target.elts + value.elts
            )
        ):
            self.features["adjacent_swap_detected"] = True

        self.generic_visit(node)

    def visit_AugAssign(self, node):
        """Count pointer moves and sliding-window shrinks."""
        if isinstance(node.target, ast.Name):
            var = node.target.id
            # += / -= on a previously detected pointer variable.
            if isinstance(node.op, (ast.Add, ast.Sub)) and (
                var in self.features["pointer_variables"]
            ):
                self.features["pointer_updates"] += 1
            # Any augmented assignment to a left-edge name shrinks the window.
            if var.lower() in ("left", "l", "start"):
                self.features["window_shrinks"] += 1
        self.generic_visit(node)

    # ------ subscript access -----
    def visit_Subscript(self, node):
        """Flag reads/writes into memo/cache/dp tables (any nesting depth)."""
        # Walk up until we find the base name.
        base = node.value
        while isinstance(base, ast.Subscript):
            base = base.value
        if isinstance(base, ast.Name) and base.id.lower() in self._DP_NAMES:
            self.features["uses_dp_array"] = True
        self.generic_visit(node)

    def visit_Compare(self, node):
        """Detect memo lookups of the form ``x in memo`` / ``cache`` / ``dp``."""
        if any(isinstance(op, ast.In) for op in node.ops):
            for comparator in node.comparators:
                if (
                    isinstance(comparator, ast.Name)
                    and comparator.id.lower() in self._DP_NAMES
                ):
                    self.features["memo_lookup_detected"] = True
        self.generic_visit(node)
def extract_features(tree: ast.AST) -> dict:
    """Extract raw and aggregate structural features from a parsed module.

    Runs ``FeatureExtractor`` over *tree* and copies the deepest loop nesting
    into ``features["max_loop_depth"]``. It then derives the aggregate
    pattern flags from the raw signals the visitor collected. Flags are only
    ever switched on here, never reset.

    Args:
        tree: the AST to analyse (e.g. the result of ``ast.parse``).

    Returns:
        The extractor's ``features`` dict, updated in place.
    """
    visitor = FeatureExtractor()
    visitor.visit(tree)

    feats = visitor.features
    feats["max_loop_depth"] = visitor.max_loop_depth

    saw_for_loop = feats["for_loops"] >= 1
    saw_while_loop = feats["while_loops"] >= 1

    # BFS: a while-loop, a front pop, an append and iteration over a
    # subscript such as graph[node].
    bfs_signals = ("queue_pop_front", "queue_append_detected", "graph_iteration")
    if saw_while_loop and all(feats[key] for key in bfs_signals):
        feats["bfs_pattern"] = True

    # High-confidence memoization: recursion plus a memo dict that is
    # defined, checked with ``in`` and written to.
    memo_signals = (
        "recursion",
        "memo_dict_defined",
        "memo_lookup_detected",
        "memo_store_detected",
    )
    if all(feats[key] for key in memo_signals):
        feats["memoization_pattern"] = True

    # Tabulation: a dp table that is indexed, read on the right-hand side of
    # its own update, and at least one for-loop.
    if feats["uses_dp_array"] and feats["dp_self_dependency"] and saw_for_loop:
        feats["tabulation_pattern"] = True

    # Either DP style counts as dynamic programming.
    if feats["memoization_pattern"] or feats["tabulation_pattern"]:
        feats["dp_pattern"] = True

    # Sorting requires nested loops; an adjacent swap takes precedence
    # over an element shift.
    if feats["max_loop_depth"] >= 2:
        sort_flag = None
        if feats["adjacent_swap_detected"]:
            sort_flag = "bubble_sort_pattern"
        elif feats["insertion_shift_detected"]:
            sort_flag = "insertion_sort_pattern"
        if sort_flag is not None:
            feats[sort_flag] = True
            feats["sorting_pattern"] = True

    # Sliding window: both loop kinds present, a for-loop target named like
    # a right edge, and an augmented assignment to a left-edge name.
    window_checks = (
        saw_for_loop,
        saw_while_loop,
        feats["window_updates"] >= 1,
        feats["window_shrinks"] >= 1,
    )
    if all(window_checks):
        feats["sliding_window_pattern"] = True

    # Heap: heapq was imported and at least one heap operation was counted.
    if feats["heap_imported"] and feats["heap_operations"] >= 1:
        feats["heap_pattern"] = True

    return feats