import ast
|
|
|
class FeatureExtractor(ast.NodeVisitor):
    """Traverse a Python AST and collect structural features of the source.

    Results accumulate in ``self.features``, a dict of counters, flags and
    name sets. The visitor only records low-level signals (loop counts, calls,
    assignment shapes, ...) plus a few flags that can be decided on the spot
    (``recursion``, ``dfs_pattern``, ``divide_and_conquer``, ...).

    All detections are name- and shape-based heuristics: for example a
    variable literally called ``stack`` marks ``uses_stack``, and subscripting
    a variable called ``dp``/``memo``/``cache`` marks ``uses_dp_array``.
    """

    # Variable names treated as DP / memoization tables (case-insensitive).
    _DP_NAMES = frozenset({"memo", "cache", "dp"})
    # heapq functions that count as a heap operation.
    _HEAP_FUNCS = frozenset({"heappush", "heappop", "heapify"})
    # Loop-variable names counted as the advancing edge of a window.
    _WINDOW_RIGHT_NAMES = ("right", "r", "end")
    # Names whose augmented assignment is counted as shrinking a window.
    _WINDOW_LEFT_NAMES = ("left", "l", "start")

    def __init__(self):
        self.features = {
            "for_loops": 0,
            "while_loops": 0,
            "function_calls": set(),
            "recursion": False,
            "max_loop_depth": 0,
            "recursive_call_count": 0,
            "divide_and_conquer": False,
            "binary_search_pattern": False,

            "pointer_variables": set(),
            "pointer_updates": 0,

            "bfs_pattern": False,
            "queue_variables": set(),
            "queue_operations": 0,
            "queue_pop_front": False,
            "queue_append_detected": False,
            "graph_iteration": False,

            "dfs_pattern": False,
            "uses_stack": False,
            "uses_pop": False,

            "dp_pattern": False,
            "uses_dp_array": False,

            "sorting_pattern": False,
            "bubble_sort_pattern": False,
            "insertion_sort_pattern": False,
            "adjacent_swap_detected": False,
            "insertion_shift_detected": False,

            "memoization_pattern": False,
            "memo_dict_defined": False,
            "memo_lookup_detected": False,
            "memo_store_detected": False,

            "tabulation_pattern": False,
            "dp_self_dependency": False,
            "dp_dimension": 1,

            "merge_sort_pattern": False,
            "quick_sort_pattern": False,

            "sliding_window_pattern": False,
            "window_updates": 0,
            "window_shrinks": 0,

            "heap_imported": False,
            "heap_operations": 0,
            "heap_pattern": False,
        }
        # Never updated by this class; kept so existing readers of the
        # attribute do not break.
        self.current_function = None

        self.current_loop_depth = 0
        self.max_loop_depth = 0

        # Name of the innermost function definition being visited.
        self.current_function_name = None

        # Names under which the heapq module is reachable. "heapq" is always
        # accepted so ``heapq.heappush(...)`` counts even without a visible
        # import, as it always did.
        self._heap_module_aliases = {"heapq"}
        # Local names bound to heapq functions via ``from heapq import ...``.
        self._heap_function_names = set()

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _subscript_base_name(node):
        """Return the lowercased root name of a (possibly nested) subscript.

        ``dp[i][j]`` yields ``"dp"``; returns ``None`` when the root
        expression is not a plain name.
        """
        base = node
        while isinstance(base, ast.Subscript):
            base = base.value
        return base.id.lower() if isinstance(base, ast.Name) else None

    def _enter_loop(self) -> None:
        """Increase the loop nesting depth and track the maximum seen."""
        self.current_loop_depth += 1
        if self.current_loop_depth > self.max_loop_depth:
            self.max_loop_depth = self.current_loop_depth
        # Mirror into the features dict so direct users of the visitor do not
        # see a stale 0.
        self.features["max_loop_depth"] = self.max_loop_depth

    def _record_named_call(self, node: ast.Call, name: str) -> None:
        """Record a call whose callee is a bare name (``f(...)``)."""
        self.features["function_calls"].add(name)

        # ``from heapq import heappush`` style usage.
        if name in self._heap_function_names:
            self.features["heap_operations"] += 1

        if name != self.current_function_name:
            return

        # The call targets the function currently being defined.
        # NOTE(review): the recursive-DFS and divide-and-conquer checks below
        # are applied to self-recursive calls only -- confirm this matches
        # the intended heuristic.
        self.features["recursion"] = True
        self.features["recursive_call_count"] += 1

        # A recursive call seen after at least one ``for`` loop has been
        # visited is treated as recursive DFS.
        if self.features["for_loops"] >= 1:
            self.features["dfs_pattern"] = True

        for arg in node.args:
            if isinstance(arg, ast.BinOp) and isinstance(
                arg.op, (ast.FloorDiv, ast.Div)
            ):
                # e.g. f(n // 2)
                self.features["divide_and_conquer"] = True
            elif isinstance(arg, ast.Subscript) and isinstance(
                arg.slice, ast.Slice
            ):
                # e.g. f(arr[:mid])
                self.features["divide_and_conquer"] = True
                self.features["merge_sort_pattern"] = True

    def _record_method_call(self, node: ast.Call, func: ast.Attribute) -> None:
        """Record a call whose callee is an attribute (``obj.method(...)``)."""
        method = func.attr
        receiver = func.value.id if isinstance(func.value, ast.Name) else None

        if receiver in self.features["queue_variables"] and method in (
            "append",
            "popleft",
        ):
            self.features["queue_operations"] += 1

        if method == "pop":
            self.features["uses_pop"] = True
            # ``x.pop(0)`` removes from the front, i.e. queue-like usage.
            if (
                node.args
                and isinstance(node.args[0], ast.Constant)
                and node.args[0].value == 0
            ):
                self.features["queue_pop_front"] = True
        elif method == "append":
            self.features["queue_append_detected"] = True
        elif method == "popleft":
            self.features["queue_pop_front"] = True

        if receiver in self._heap_module_aliases and method in self._HEAP_FUNCS:
            self.features["heap_operations"] += 1

    # ------------------------------------------------------------------
    # visitors
    # ------------------------------------------------------------------
    def visit_Import(self, node):
        """Flag ``import heapq`` and remember any alias it is bound to."""
        for alias in node.names:
            if alias.name == "heapq":
                self.features["heap_imported"] = True
                self._heap_module_aliases.add(alias.asname or alias.name)
        self.generic_visit(node)

    def visit_ImportFrom(self, node):
        """Flag ``from heapq import ...`` and remember imported heap functions."""
        if node.module == "heapq":
            self.features["heap_imported"] = True
            for alias in node.names:
                if alias.name == "*":
                    self._heap_function_names.update(self._HEAP_FUNCS)
                elif alias.name in self._HEAP_FUNCS:
                    self._heap_function_names.add(alias.asname or alias.name)
        self.generic_visit(node)

    def visit_FunctionDef(self, node):
        """Track the enclosing function name while visiting its body."""
        previous_function = self.current_function_name
        self.current_function_name = node.name
        self.generic_visit(node)
        # Restore so calls after a nested def are attributed correctly.
        self.current_function_name = previous_function

    # ``async def`` bodies need the same bookkeeping for recursion detection.
    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_For(self, node):
        """Count a ``for`` loop and record loop-shape signals."""
        self.features["for_loops"] += 1
        self._enter_loop()

        # ``for x in graph[node]`` -- iterating a subscripted expression.
        if isinstance(node.iter, ast.Subscript):
            self.features["graph_iteration"] = True

        self.generic_visit(node)
        self.current_loop_depth -= 1

        if (
            isinstance(node.target, ast.Name)
            and node.target.id.lower() in self._WINDOW_RIGHT_NAMES
        ):
            self.features["window_updates"] += 1

    def visit_While(self, node):
        """Count a ``while`` loop and track nesting depth."""
        self.features["while_loops"] += 1
        self._enter_loop()
        self.generic_visit(node)
        self.current_loop_depth -= 1

    def visit_Call(self, node):
        """Record call-related signals (recursion, queue/stack/heap usage)."""
        func = node.func
        if isinstance(func, ast.Name):
            self._record_named_call(node, func.id)
        elif isinstance(func, ast.Attribute):
            self._record_method_call(node, func)

        # Stack-based DFS: evaluated on every call, so it reflects only what
        # has been visited so far (traversal-order dependent).
        if (
            self.features["uses_stack"]
            and self.features["uses_pop"]
            and self.features["for_loops"] >= 1
        ):
            self.features["dfs_pattern"] = True

        self.generic_visit(node)

    def visit_Assign(self, node):
        """Record assignment-shape signals.

        Only the first assignment target is inspected.
        """
        value = node.value
        first = node.targets[0] if node.targets else None

        # mid = (lo + hi) // 2
        if (
            isinstance(value, ast.BinOp)
            and isinstance(value.op, ast.FloorDiv)
            and isinstance(value.left, ast.BinOp)
            and isinstance(value.left.op, ast.Add)
        ):
            self.features["binary_search_pattern"] = True

        if isinstance(first, ast.Name):
            name = first.id
            lowered = name.lower()

            # Names initialised from a constant or an arithmetic expression
            # are candidate "pointer" variables. ``ast.Constant`` is the only
            # literal node on Python 3.8+; the removed ``ast.Num`` alias is
            # deliberately not referenced.
            if isinstance(value, (ast.Constant, ast.BinOp)):
                self.features["pointer_variables"].add(name)

            # q = deque(...)
            if (
                isinstance(value, ast.Call)
                and isinstance(value.func, ast.Name)
                and value.func.id == "deque"
            ):
                self.features["queue_variables"].add(name)

            # stack = [...] / stack = something(...)
            if isinstance(value, (ast.List, ast.Call)) and lowered == "stack":
                self.features["uses_stack"] = True

            # memo = {...}
            if isinstance(value, ast.Dict) and lowered in self._DP_NAMES:
                self.features["memo_dict_defined"] = True

            if lowered == "pivot":
                self.features["quick_sort_pattern"] = True

        if isinstance(first, ast.Subscript):
            # memo[key] = ...  (direct, single-level subscript only)
            if (
                isinstance(first.value, ast.Name)
                and first.value.id.lower() in self._DP_NAMES
            ):
                self.features["memo_store_detected"] = True

            # dp[i] = f(dp[...]) -- the table is read while being written.
            table = self._subscript_base_name(first)
            if table in self._DP_NAMES and any(
                isinstance(child, ast.Name) and child.id.lower() == table
                for child in ast.walk(value)
            ):
                self.features["dp_self_dependency"] = True

            # a[j + 1] = a[j]
            if isinstance(value, ast.Subscript):
                self.features["insertion_shift_detected"] = True

        # NOTE(review): any list comprehension or nested list literal on the
        # right-hand side marks the table as 2-D, whatever the target name.
        if isinstance(value, ast.ListComp):
            self.features["dp_dimension"] = 2
        if isinstance(value, ast.List) and any(
            isinstance(el, ast.List) for el in value.elts
        ):
            self.features["dp_dimension"] = 2

        # a[i], a[j] = a[j], a[i]
        if (
            isinstance(first, ast.Tuple)
            and isinstance(value, ast.Tuple)
            and len(first.elts) == 2
            and len(value.elts) == 2
            and all(isinstance(el, ast.Subscript) for el in first.elts + value.elts)
        ):
            self.features["adjacent_swap_detected"] = True

        self.generic_visit(node)

    def visit_AugAssign(self, node):
        """Count pointer updates (``i += 1``) and window shrinks (``left += 1``)."""
        target = node.target
        if isinstance(target, ast.Name):
            if (
                isinstance(node.op, (ast.Add, ast.Sub))
                and target.id in self.features["pointer_variables"]
            ):
                self.features["pointer_updates"] += 1

            if target.id.lower() in self._WINDOW_LEFT_NAMES:
                self.features["window_shrinks"] += 1

        self.generic_visit(node)

    def visit_Subscript(self, node):
        """Flag any subscript whose root name is a DP/memo table name."""
        if self._subscript_base_name(node) in self._DP_NAMES:
            self.features["uses_dp_array"] = True
        self.generic_visit(node)

    def visit_Compare(self, node):
        """Flag membership tests against a DP/memo table (``key in memo``)."""
        if any(isinstance(op, ast.In) for op in node.ops):
            for comparator in node.comparators:
                if (
                    isinstance(comparator, ast.Name)
                    and comparator.id.lower() in self._DP_NAMES
                ):
                    self.features["memo_lookup_detected"] = True
        self.generic_visit(node)
|
|
|
|
|
def extract_features(tree: ast.AST) -> dict:
    """Run ``FeatureExtractor`` over *tree* and derive the composite pattern flags.

    The visitor gathers the raw signals; this function then copies the maximum
    loop depth into the feature dict and switches on the higher-level pattern
    flags (BFS, memoization, tabulation, DP, bubble/insertion sort, sliding
    window, heap) whose prerequisites are all present.

    Args:
        tree: Parsed AST to analyse.

    Returns:
        The extractor's feature dict, with the composite flags filled in.
    """
    extractor = FeatureExtractor()
    extractor.visit(tree)

    feats = extractor.features
    feats["max_loop_depth"] = extractor.max_loop_depth

    has_for_loop = feats["for_loops"] >= 1
    has_while_loop = feats["while_loops"] >= 1

    # BFS: a while loop, front-pop + append on a queue, and iteration over a
    # subscripted expression.
    bfs_signals = ("queue_pop_front", "queue_append_detected", "graph_iteration")
    if has_while_loop and all(feats[key] for key in bfs_signals):
        feats["bfs_pattern"] = True

    # Top-down DP: recursion plus a memo dict that is defined, read and written.
    memo_signals = (
        "recursion",
        "memo_dict_defined",
        "memo_lookup_detected",
        "memo_store_detected",
    )
    if all(feats[key] for key in memo_signals):
        feats["memoization_pattern"] = True

    # Bottom-up DP: a table that depends on itself, filled inside a for loop.
    if feats["uses_dp_array"] and feats["dp_self_dependency"] and has_for_loop:
        feats["tabulation_pattern"] = True

    if feats["memoization_pattern"] or feats["tabulation_pattern"]:
        feats["dp_pattern"] = True

    # Quadratic sorts need nested loops; a detected swap takes precedence over
    # a detected shift.
    if feats["max_loop_depth"] >= 2:
        detected_sort = None
        if feats["adjacent_swap_detected"]:
            detected_sort = "bubble_sort_pattern"
        elif feats["insertion_shift_detected"]:
            detected_sort = "insertion_sort_pattern"

        if detected_sort is not None:
            feats[detected_sort] = True
            feats["sorting_pattern"] = True

    # Sliding window: for + while loops with both window edges moving.
    if (
        has_for_loop
        and has_while_loop
        and feats["window_updates"] >= 1
        and feats["window_shrinks"] >= 1
    ):
        feats["sliding_window_pattern"] = True

    if feats["heap_imported"] and feats["heap_operations"] >= 1:
        feats["heap_pattern"] = True

    return feats
|
|
|
|
|