# Source: Cortex-Coder / scripts/main.py
# Uploaded by junaid17 ("Upload 7 files", commit eff58d2, verified).
import re
from typing import TypedDict, Literal, Optional, Annotated

from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages

from scripts.load_model import get_model
from scripts.prompts import (
    GENERAL_CHAT_SYSTEM_PROMPT,
    CODE_AGENT_SYSTEM_PROMPT,
    REPAIR_SYSTEM_PROMPT
)
# =========================================================
# LLM
# =========================================================
# Single shared model instance, created once when this module is imported and
# reused by every node that calls ``llm.invoke(...)``.
# NOTE(review): get_model() lives in scripts.load_model; presumably it returns
# a LangChain chat model (it is only used via .invoke(messages) and
# response.content here) -- confirm against that module.
llm = get_model()
# =========================================================
# STATE
# =========================================================
# Closed sets of labels written by the classifier nodes.
_Intent = Literal["coding", "general_chat"]
_TaskType = Literal["code_generation", "debugging", "explanation"]
_Language = Literal["python", "javascript", "cpp", "sql", "java", "unknown"]


class ChatState(TypedDict):
    """State dictionary passed between the nodes of the chat graph."""

    # Message list; updates are merged through the ``add_messages`` reducer.
    user_input: Annotated[list[BaseMessage], add_messages]

    # Label written by ``classify_intent``; ``None`` until it has run.
    intent: Optional[_Intent]

    # Label written by ``classify_task`` (coding branch only).
    task_type: Optional[_TaskType]

    # Label written by ``detect_language`` (coding branch only).
    language: Optional[_Language]

    # Latest model reply, written by the chat, generate and repair nodes.
    generated_output: Optional[str]

    # Result of the marker check performed by ``classify_output``.
    is_code: Optional[bool]

    # Number of repair passes performed so far.
    retry_count: int
# =========================================================
# HELPERS
# =========================================================
def get_last_user_message(state: ChatState) -> str:
    """Return the text of the most recent message in ``state["user_input"]``.

    Args:
        state: Graph state; only the ``"user_input"`` message list is read.

    Returns:
        The text of the last message, or ``""`` when the list is empty.
        Message content is not always a plain string: it can also be a list
        of content blocks (bare strings or dicts such as
        ``{"type": "text", "text": "..."}``). In that case the text parts
        are concatenated and non-text blocks are skipped, so callers can
        safely apply ``str`` methods such as ``.lower()`` to the result.
    """
    messages = state["user_input"]
    if not messages:
        return ""

    content = messages[-1].content
    if isinstance(content, str):
        return content
    if content is None:
        return ""
    if isinstance(content, list):
        text_parts = []
        for block in content:
            if isinstance(block, str):
                text_parts.append(block)
            elif (
                isinstance(block, dict)
                and block.get("type") == "text"
                and isinstance(block.get("text"), str)
            ):
                text_parts.append(block["text"])
        return "".join(text_parts)

    # Unknown content shape: fall back to its string form rather than
    # handing a non-string to callers that expect ``str``.
    return str(content)
# =========================================================
# INTENT CLASSIFIER
# =========================================================
def classify_intent(state: ChatState):
    """Label the latest message as ``"coding"`` or ``"general_chat"``.

    The lower-cased message is scanned for a fixed list of coding terms.
    Terms are matched as plain substrings, so a term also matches inside a
    longer word (for example ``"api"`` inside ``"fastapi"`` or ``"capital"``).

    Returns:
        A state update of the form ``{"intent": <label>}``.
    """
    coding_terms = (
        "code", "python", "javascript", "java", "c++", "sql",
        "bug", "debug", "fix", "function", "algorithm",
        "class", "implement", "build", "api", "query",
        "optimize", "refactor", "script", "program",
    )
    text = get_last_user_message(state).lower()

    for term in coding_terms:
        if term in text:
            return {"intent": "coding"}
    return {"intent": "general_chat"}
def route_intent(state: ChatState):
    """Pick the branch to follow after intent classification.

    Returns:
        ``"general_chat"`` when ``state["intent"]`` is exactly that label;
        ``"coding"`` for every other value (including ``None``).
    """
    return "general_chat" if state["intent"] == "general_chat" else "coding"
# =========================================================
# GENERAL CHAT
# =========================================================
def handle_general_chat(state: ChatState):
    """Answer a non-coding message with the general chat system prompt.

    Only the latest message text is sent to the model, preceded by
    ``GENERAL_CHAT_SYSTEM_PROMPT``.

    Returns:
        A state update ``{"generated_output": <model reply content>}``.
    """
    user_text = get_last_user_message(state)
    conversation = [
        SystemMessage(content=GENERAL_CHAT_SYSTEM_PROMPT),
        HumanMessage(content=user_text),
    ]
    reply = llm.invoke(conversation)
    return {"generated_output": reply.content}
# =========================================================
# TASK CLASSIFIER
# =========================================================
# Phrases that mark a request as a debugging task.
_DEBUGGING_KEYWORDS = (
    "fix", "debug", "error", "broken",
    "issue", "not working", "exception", "traceback",
)

# Phrases that mark a request as asking for an explanation.
_EXPLANATION_KEYWORDS = (
    "explain", "what is", "how does", "why", "difference between",
)


def _compile_word_start_pattern(keywords):
    """Build a regex matching any of *keywords* at the start of a word."""
    alternatives = "|".join(re.escape(keyword) for keyword in keywords)
    # The lookbehind rejects hits inside a longer word ("fix" in "prefix",
    # "issue" in "tissue") while still accepting inflected forms such as
    # "fixing", "errors" or "debugging".
    return re.compile(r"(?<!\w)(?:" + alternatives + r")")


_DEBUGGING_RE = _compile_word_start_pattern(_DEBUGGING_KEYWORDS)
_EXPLANATION_RE = _compile_word_start_pattern(_EXPLANATION_KEYWORDS)


def classify_task(state: ChatState):
    """Label the latest message as debugging, explanation or code generation.

    The lower-cased message is checked against the debugging phrases first,
    then the explanation phrases; if neither matches the task is
    ``"code_generation"``. A phrase only counts when it starts at a word
    boundary, so a request such as "explain what a prefix tree is" is no
    longer treated as debugging just because "prefix" contains "fix".
    As a consequence, compounds that merely end in a keyword (for example
    "hotfix") do not match either.

    Returns:
        A state update of the form ``{"task_type": <label>}``.
    """
    query = get_last_user_message(state).lower()

    if _DEBUGGING_RE.search(query):
        task = "debugging"
    elif _EXPLANATION_RE.search(query):
        task = "explanation"
    else:
        task = "code_generation"
    return {"task_type": task}
# =========================================================
# LANGUAGE DETECTOR
# =========================================================
# Explicit language names, in priority order. "js" must end a word so that it
# matches "js", "node.js" or "nodejs" but not "json".
_LANGUAGE_NAME_PATTERNS = (
    ("python", re.compile(r"python")),
    ("javascript", re.compile(r"javascript|js(?!\w)")),
    ("cpp", re.compile(r"c\+\+")),
    ("java", re.compile(r"java")),
    ("sql", re.compile(r"sql")),
)

# Code fragments that hint at a language when none is named, in the same
# priority order.
_LANGUAGE_SYNTAX_MARKERS = (
    ("python", ("def ", "import ", "print(")),
    ("javascript", ("function ", "console.log", "const ", "let ")),
    ("cpp", ("#include", "std::", "cout")),
    ("java", ("public class", "system.out")),
    ("sql", ("select ", "insert ", "update ", "delete ", "join ")),
)


def detect_language(state: ChatState):
    """Guess which programming language the latest message is about.

    Detection runs in two passes over the lower-cased message:

    1. Explicit language names. A named language always wins, so
       "write a function in c++" yields ``"cpp"`` even though the text also
       contains the JavaScript marker ``"function "``.
    2. Syntax markers (``"def "``, ``"console.log"``, ``"#include"`` ...),
       used only when no language is named.

    Within each pass the first match in the order python, javascript, cpp,
    java, sql is returned.

    Returns:
        A state update ``{"language": <label>}``; the label is ``"unknown"``
        when neither pass matches.
    """
    query = get_last_user_message(state).lower()

    for lang, name_pattern in _LANGUAGE_NAME_PATTERNS:
        if name_pattern.search(query):
            return {"language": lang}

    for lang, markers in _LANGUAGE_SYNTAX_MARKERS:
        if any(marker in query for marker in markers):
            return {"language": lang}

    return {"language": "unknown"}
# =========================================================
# GENERATOR
# =========================================================
def generate_code(state: ChatState):
    """Ask the model to fulfil the coding request.

    The human message combines the previously classified ``task_type`` and
    ``language`` with the latest message text; the system message is
    ``CODE_AGENT_SYSTEM_PROMPT``.

    Returns:
        A state update ``{"generated_output": <model reply content>}``.
    """
    query = get_last_user_message(state)
    task = state["task_type"]
    language = state["language"]

    # The prompt body is kept flush-left so no indentation leaks into it.
    request = f"""
Task Type: {task}
Requested Language: {language}
User Request:
{query}
"""
    reply = llm.invoke(
        [
            SystemMessage(content=CODE_AGENT_SYSTEM_PROMPT),
            HumanMessage(content=request),
        ]
    )
    return {"generated_output": reply.content}
# =========================================================
# OUTPUT CLASSIFIER
# =========================================================
def classify_output(state: ChatState):
    """Decide whether the generated output looks like source code.

    The output is lower-cased and searched for a fixed list of code
    fragments; a missing or empty output is treated as an empty string.

    Returns:
        A state update ``{"is_code": True}`` if any fragment occurs in the
        output, otherwise ``{"is_code": False}``.
    """
    code_fragments = (
        "def ", "class ", "function ", "const ", "let ",
        "#include", "std::", "public class", "system.out",
        "select ", "insert ", "update ", "delete ",
        "console.log", "print(", "fn ",
    )
    raw_output = state["generated_output"]
    text = raw_output.lower() if raw_output else ""

    for fragment in code_fragments:
        if fragment in text:
            return {"is_code": True}
    return {"is_code": False}
# =========================================================
# ROUTER
# =========================================================
def route_output(state: ChatState):
    """Choose between finishing and running a repair pass.

    Returns:
        ``"final"`` when the task is an explanation, when the output was
        classified as code, or when ``retry_count`` has reached 2;
        ``"repair"`` otherwise.
    """
    # Conditions are evaluated left to right and stop at the first hit.
    finished = (
        state["task_type"] == "explanation"
        or state["is_code"]
        or state["retry_count"] >= 2
    )
    return "final" if finished else "repair"
# =========================================================
# REPAIR
# =========================================================
def repair_code(state: ChatState):
    """Ask the model to turn a non-code reply into executable code.

    The previous output (or ``""`` if there is none) and the detected
    language are sent together with ``REPAIR_SYSTEM_PROMPT``.

    Returns:
        A state update with the new ``generated_output`` and ``retry_count``
        increased by one.
    """
    bad_output = state["generated_output"] or ""
    language = state["language"]

    # The prompt body is kept flush-left so no indentation leaks into it.
    instructions = f"""
The previous response was expected to be executable code but was not.
Requested language:
{language}
Bad output:
{bad_output}
Return ONLY executable code.
No explanation.
No markdown.
"""
    reply = llm.invoke(
        [
            SystemMessage(content=REPAIR_SYSTEM_PROMPT),
            HumanMessage(content=instructions),
        ]
    )

    update = {"generated_output": reply.content}
    update["retry_count"] = state["retry_count"] + 1
    return update
# =========================================================
# GRAPH
# =========================================================
# Checkpointer handed to compile() below; MemorySaver is the in-memory saver
# imported from langgraph.checkpoint.memory.
checkpointer = MemorySaver()
builder = StateGraph(ChatState)
# Register each node function under the name the edges below refer to.
builder.add_node("classify_intent", classify_intent)
builder.add_node("general_chat", handle_general_chat)
builder.add_node("classify_task", classify_task)
builder.add_node("detect_language", detect_language)
builder.add_node("generate_code", generate_code)
builder.add_node("classify_output", classify_output)
builder.add_node("repair", repair_code)
# Every run starts with intent classification.
builder.set_entry_point("classify_intent")
# route_intent's return value selects the branch: a plain chat reply or the
# coding pipeline (which begins at classify_task).
builder.add_conditional_edges(
"classify_intent",
route_intent,
{
"general_chat": "general_chat",
"coding": "classify_task"
}
)
# The chat branch ends right after the reply.
builder.add_edge("general_chat", END)
# Coding pipeline: classify task -> detect language -> generate -> check output.
builder.add_edge("classify_task", "detect_language")
builder.add_edge("detect_language", "generate_code")
builder.add_edge("generate_code", "classify_output")
# route_output either finishes the run or sends the output to the repair node.
builder.add_conditional_edges(
"classify_output",
route_output,
{
"final": END,
"repair": "repair"
}
)
# A repaired output is checked again; route_output returns "final" once
# retry_count reaches 2, which bounds this loop.
builder.add_edge("repair", "classify_output")
graph = builder.compile(checkpointer=checkpointer)
# =========================================================
# STREAM
# =========================================================
def stream_chat_response(user_message: str, thread_id: str):
    """Run the graph for one user message and yield its streamed chunks.

    Args:
        user_message: Text wrapped in a ``HumanMessage`` as the run's input.
        thread_id: Identifier placed in the run config under
            ``configurable.thread_id``.

    Yields:
        Each chunk produced by ``graph.stream`` with
        ``stream_mode="messages"``, unchanged.
    """
    # Classification fields start empty and the retry counter at zero on
    # every call.
    fresh_state = dict(
        user_input=[HumanMessage(content=user_message)],
        intent=None,
        task_type=None,
        language=None,
        generated_output=None,
        is_code=None,
        retry_count=0,
    )
    run_config = {"configurable": {"thread_id": thread_id}}

    yield from graph.stream(
        fresh_state,
        config=run_config,
        stream_mode="messages",
    )
def get_chat_metadata(thread_id: str):
    """Return the classification fields stored in a thread's graph state.

    Args:
        thread_id: Identifier placed in the lookup config under
            ``configurable.thread_id``.

    Returns:
        A dict with the keys ``intent``, ``task_type``, ``language``,
        ``is_code`` and ``retry_count``; a key missing from the stored
        values maps to ``None``.
    """
    snapshot = graph.get_state({"configurable": {"thread_id": thread_id}})
    stored = snapshot.values
    reported_keys = ("intent", "task_type", "language", "is_code", "retry_count")
    return {key: stored.get(key) for key in reported_keys}