import ast

# Variable names treated as DP / memoization containers. Names found in the
# analysed source are lowercased before being compared against this set.
_DP_TABLE_NAMES = frozenset({"dp", "memo", "cache"})

# heapq functions whose calls are counted in features["heap_operations"].
_HEAP_OPERATIONS = frozenset({"heappush", "heappop", "heapify"})


def _subscript_base_name(node):
    """Return the lowercased root variable of a (possibly nested) subscript.

    For ``dp[i][j]`` this returns ``"dp"``. Returns ``None`` when the
    innermost subscripted expression is not a plain name.
    """
    base = node
    while isinstance(base, ast.Subscript):
        base = base.value
    if isinstance(base, ast.Name):
        return base.id.lower()
    return None


class FeatureExtractor(ast.NodeVisitor):
    """
    Traverses the AST and extracts structural features
    from Python source code.

    Results accumulate in ``self.features`` while visiting. The composite
    ``*_pattern`` flags that combine several raw signals are filled in by
    ``extract_features`` after the traversal.
    """

    def __init__(self):
        self.features = {
            "for_loops": 0,
            "while_loops": 0,
            "function_calls": set(),
            "recursion": False,
            "max_loop_depth": 0,
            "recursive_call_count": 0,
            "divide_and_conquer": False,
            "binary_search_pattern": False,
            "pointer_variables": set(),
            "pointer_updates": 0,
            "bfs_pattern": False,
            "queue_variables": set(),
            "queue_operations": 0,
            "queue_pop_front": False,
            "queue_append_detected": False,
            "graph_iteration": False,
            "dfs_pattern": False,
            "uses_stack": False,
            "uses_pop": False,
            "dp_pattern": False,
            "uses_dp_array": False,
            "sorting_pattern": False,
            "bubble_sort_pattern": False,
            "insertion_sort_pattern": False,
            "adjacent_swap_detected": False,
            "insertion_shift_detected": False,
            "memoization_pattern": False,
            "memo_dict_defined": False,
            "memo_lookup_detected": False,
            "memo_store_detected": False,
            "tabulation_pattern": False,
            "dp_self_dependency": False,
            "dp_dimension": 1,
            "merge_sort_pattern": False,
            "quick_sort_pattern": False,
            "sliding_window_pattern": False,
            "window_updates": 0,
            "window_shrinks": 0,
            "heap_imported": False,
            "heap_operations": 0,
            "heap_pattern": False,
        }
        # Not read anywhere in this file; kept so the attribute still exists
        # for any external code that touches it.
        self.current_function = None
        self.current_loop_depth = 0
        self.max_loop_depth = 0
        self.current_function_name = None
        # Names under which the heapq module is reachable. "heapq" is always
        # present so `heapq.heappush(...)` is counted even when no import
        # statement has been visited yet.
        self._heap_module_names = {"heapq"}
        # Bare names bound to heapq functions via `from heapq import ...`.
        self._heap_function_names = set()

    # ------------------------------------------------------------ imports

    def visit_Import(self, node):
        """Record `import heapq`, including aliased forms (`as hq`)."""
        for alias in node.names:
            if alias.name == "heapq":
                self.features["heap_imported"] = True
                self._heap_module_names.add(alias.asname or alias.name)
        self.generic_visit(node)

    def visit_ImportFrom(self, node):
        """Record `from heapq import ...` and the function names it binds."""
        if node.module == "heapq":
            self.features["heap_imported"] = True
            for alias in node.names:
                if alias.name == "*":
                    self._heap_function_names.update(_HEAP_OPERATIONS)
                elif alias.name in _HEAP_OPERATIONS:
                    self._heap_function_names.add(alias.asname or alias.name)
        self.generic_visit(node)

    # ------------------------------------------------- functions and loops

    def visit_FunctionDef(self, node):
        """Track the enclosing function name so recursive calls can be spotted."""
        previous_function = self.current_function_name
        self.current_function_name = node.name
        self.generic_visit(node)
        # Restore the outer name so nested definitions do not leak.
        self.current_function_name = previous_function

    def _enter_loop(self):
        """Increase the nesting depth and update the deepest level seen."""
        self.current_loop_depth += 1
        self.max_loop_depth = max(self.max_loop_depth, self.current_loop_depth)

    def visit_For(self, node):
        """Count the loop, track nesting depth, and note graph/window hints."""
        self.features["for_loops"] += 1
        self._enter_loop()

        # Iterating over a subscript (e.g. `for n in graph[node]`).
        if isinstance(node.iter, ast.Subscript):
            self.features["graph_iteration"] = True

        self.generic_visit(node)
        self.current_loop_depth -= 1

        # A loop variable named like a right edge counts as a window update.
        if isinstance(node.target, ast.Name):
            if node.target.id.lower() in ("right", "r", "end"):
                self.features["window_updates"] += 1

    def visit_While(self, node):
        """Count the loop and track nesting depth."""
        self.features["while_loops"] += 1
        self._enter_loop()
        self.generic_visit(node)
        self.current_loop_depth -= 1

    # --------------------------------------------------------------- calls

    def visit_Call(self, node):
        """Dispatch on plain-name calls vs. attribute (method-style) calls."""
        func = node.func
        if isinstance(func, ast.Name):
            self._record_named_call(node, func.id)
        elif isinstance(func, ast.Attribute):
            self._record_method_call(node, func)

        # Iterative DFS heuristic. Evaluated at every call, so only the
        # loops and flags recorded so far in the traversal are considered.
        if (
            self.features["uses_stack"]
            and self.features["uses_pop"]
            and self.features["for_loops"] >= 1
        ):
            self.features["dfs_pattern"] = True

        self.generic_visit(node)

    def _record_named_call(self, node, function_name):
        """Handle `name(...)`: call set, heap ops, recursion and its shape."""
        self.features["function_calls"].add(function_name)

        # Functions brought in with `from heapq import heappush` etc.
        if function_name in self._heap_function_names:
            self.features["heap_operations"] += 1

        if function_name != self.current_function_name:
            return

        self.features["recursion"] = True
        self.features["recursive_call_count"] += 1

        # Recursion plus at least one for-loop seen so far: DFS-style.
        if self.features["for_loops"] >= 1:
            self.features["dfs_pattern"] = True

        for arg in node.args:
            # Halving argument: n // 2 or n / 2.
            if isinstance(arg, ast.BinOp) and isinstance(
                arg.op, (ast.FloorDiv, ast.Div)
            ):
                self.features["divide_and_conquer"] = True
            # Sliced argument such as arr[:mid]; also the merge-sort signal.
            if isinstance(arg, ast.Subscript) and isinstance(arg.slice, ast.Slice):
                self.features["divide_and_conquer"] = True
                self.features["merge_sort_pattern"] = True

    def _record_method_call(self, node, func):
        """Handle `obj.method(...)`: queue, stack and heap operations."""
        method = func.attr
        owner = func.value.id if isinstance(func.value, ast.Name) else None

        # append/popleft on a variable known to hold a deque.
        if owner in self.features["queue_variables"] and method in (
            "append",
            "popleft",
        ):
            self.features["queue_operations"] += 1

        if method == "pop":
            self.features["uses_pop"] = True
            # pop(0) is front removal on a list-based queue.
            if (
                node.args
                and isinstance(node.args[0], ast.Constant)
                and node.args[0].value == 0
            ):
                self.features["queue_pop_front"] = True
        elif method == "append":
            self.features["queue_append_detected"] = True
        elif method == "popleft":
            self.features["queue_pop_front"] = True

        # heapq.heappush(...) or the same through an alias (hq.heappush).
        if owner in self._heap_module_names and method in _HEAP_OPERATIONS:
            self.features["heap_operations"] += 1

    # --------------------------------------------------------- assignments

    def visit_Assign(self, node):
        """Inspect an assignment; only the first target is considered."""
        value = node.value
        target = node.targets[0] if node.targets else None

        # Binary search midpoint: <a + b> // <c>.
        if (
            isinstance(value, ast.BinOp)
            and isinstance(value.op, ast.FloorDiv)
            and isinstance(value.left, ast.BinOp)
            and isinstance(value.left.op, ast.Add)
        ):
            self.features["binary_search_pattern"] = True

        if isinstance(target, ast.Name):
            self._inspect_name_assignment(target.id, value)
        elif isinstance(target, ast.Subscript):
            self._inspect_subscript_assignment(target, value)
        elif (
            isinstance(target, ast.Tuple)
            and isinstance(value, ast.Tuple)
            and len(target.elts) == 2
            and len(value.elts) == 2
        ):
            # a[i], a[j] = a[j], a[i]: a two-element swap of subscripts.
            if all(isinstance(el, ast.Subscript) for el in target.elts + value.elts):
                self.features["adjacent_swap_detected"] = True

        # Any list comprehension, or a list literal that contains a list,
        # is treated as a 2D table.
        if isinstance(value, ast.ListComp):
            self.features["dp_dimension"] = 2
        elif isinstance(value, ast.List):
            if any(isinstance(el, ast.List) for el in value.elts):
                self.features["dp_dimension"] = 2

        self.generic_visit(node)

    def _inspect_name_assignment(self, name, value):
        """Handle `name = value`: pointers, queues, stacks, memo dicts, pivot."""
        lowered = name.lower()

        # Two-pointer candidates: `left = 0` or `right = len(arr) - 1`.
        # ast.Constant alone is enough here: ast.Num is a deprecated alias
        # of it and is removed in Python 3.14.
        if isinstance(value, (ast.Constant, ast.BinOp)):
            self.features["pointer_variables"].add(name)

        # q = deque(...)
        if (
            isinstance(value, ast.Call)
            and isinstance(value.func, ast.Name)
            and value.func.id == "deque"
        ):
            self.features["queue_variables"].add(name)

        # stack = [...] or stack = something(...)
        if lowered == "stack" and isinstance(value, (ast.List, ast.Call)):
            self.features["uses_stack"] = True

        # memo = {...}
        if lowered in _DP_TABLE_NAMES and isinstance(value, ast.Dict):
            self.features["memo_dict_defined"] = True

        # Quick sort signal: a variable named "pivot" is assigned.
        if lowered == "pivot":
            self.features["quick_sort_pattern"] = True

    def _inspect_subscript_assignment(self, target, value):
        """Handle `x[...] = value`: memo stores, DP recurrences, shifts."""
        # memo[n] = ... (single-level subscript on a memo-like name).
        if (
            isinstance(target.value, ast.Name)
            and target.value.id.lower() in _DP_TABLE_NAMES
        ):
            self.features["memo_store_detected"] = True

        # Tabulation recurrence: dp[...] (any depth, e.g. dp[i][w]) is
        # assigned from an expression that itself mentions the same table.
        base = _subscript_base_name(target)
        if base in _DP_TABLE_NAMES:
            if any(
                isinstance(child, ast.Name) and child.id.lower() == base
                for child in ast.walk(value)
            ):
                self.features["dp_self_dependency"] = True

        # Insertion sort shift: a[j + 1] = a[j].
        if isinstance(value, ast.Subscript):
            self.features["insertion_shift_detected"] = True

    def visit_AugAssign(self, node):
        """Count pointer moves (`left += 1`) and window shrink candidates."""
        if isinstance(node.target, ast.Name):
            name = node.target.id
            if (
                isinstance(node.op, (ast.Add, ast.Sub))
                and name in self.features["pointer_variables"]
            ):
                self.features["pointer_updates"] += 1
            # Any augmented update of a left-edge-like name counts as a
            # shrink, regardless of the operator.
            if name.lower() in ("left", "l", "start"):
                self.features["window_shrinks"] += 1
        self.generic_visit(node)

    # ------------------------------------------- subscripts and comparisons

    def visit_Subscript(self, node):
        """Flag any indexing whose root variable is a DP/memo-like name."""
        if _subscript_base_name(node) in _DP_TABLE_NAMES:
            self.features["uses_dp_array"] = True
        self.generic_visit(node)

    def visit_Compare(self, node):
        """Detect membership tests against a memo: `x in memo`, `x not in memo`."""
        if any(isinstance(op, (ast.In, ast.NotIn)) for op in node.ops):
            for comparator in node.comparators:
                if (
                    isinstance(comparator, ast.Name)
                    and comparator.id.lower() in _DP_TABLE_NAMES
                ):
                    self.features["memo_lookup_detected"] = True
        self.generic_visit(node)


def extract_features(tree: ast.AST) -> dict:
    """Run ``FeatureExtractor`` over *tree* and derive the composite patterns.

    Args:
        tree: A parsed AST, e.g. the result of ``ast.parse(source)``.

    Returns:
        The extractor's feature dictionary, with ``max_loop_depth`` and the
        combined ``*_pattern`` flags (BFS, memoization, tabulation, DP,
        sorting, sliding window, heap) filled in from the raw signals.
    """
    extractor = FeatureExtractor()
    extractor.visit(tree)
    features = extractor.features

    features["max_loop_depth"] = extractor.max_loop_depth

    # BFS: while-loop + front removal + append + iteration over a subscript.
    if (
        features["while_loops"] >= 1
        and features["queue_pop_front"]
        and features["queue_append_detected"]
        and features["graph_iteration"]
    ):
        features["bfs_pattern"] = True

    # High-confidence memoization: recursion plus define/lookup/store.
    if (
        features["recursion"]
        and features["memo_dict_defined"]
        and features["memo_lookup_detected"]
        and features["memo_store_detected"]
    ):
        features["memoization_pattern"] = True

    # Tabulation: a DP table filled from its own entries inside a for-loop.
    if (
        features["uses_dp_array"]
        and features["dp_self_dependency"]
        and features["for_loops"] >= 1
    ):
        features["tabulation_pattern"] = True

    # Final DP pattern.
    if features["memoization_pattern"] or features["tabulation_pattern"]:
        features["dp_pattern"] = True

    # Sorting: nested loops plus a swap (bubble) or, failing that, a shift
    # (insertion). The swap signal takes precedence.
    if features["max_loop_depth"] >= 2:
        if features["adjacent_swap_detected"]:
            features["bubble_sort_pattern"] = True
            features["sorting_pattern"] = True
        elif features["insertion_shift_detected"]:
            features["insertion_sort_pattern"] = True
            features["sorting_pattern"] = True

    # Sliding window heuristic.
    if (
        features["for_loops"] >= 1
        and features["while_loops"] >= 1
        and features["window_updates"] >= 1
        and features["window_shrinks"] >= 1
    ):
        features["sliding_window_pattern"] = True

    # Heap pattern.
    if features["heap_imported"] and features["heap_operations"] >= 1:
        features["heap_pattern"] = True

    return features