| | import gradio as gr |
| | import torch |
| | import numpy as np |
| |
|
| | from sentence_transformers import SentenceTransformer, util |
| |
|
| | |
| | |
# Identifier of the embedding checkpoint handed to SentenceTransformer below.
model_name = "juanwisz/modernbert-python-code-retrieval"

# Use the GPU whenever torch reports that CUDA is available; otherwise stay on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Built once at import time and shared by every call to retrieve_top_snippets.
embedding_model = SentenceTransformer(model_name, device=device)
| |
|
| | |
| | |
| | |
| | |
def retrieve_top_snippets(query, code_input):
    """Rank pasted code snippets against a natural-language query.

    ``code_input`` is split on the literal ``---`` delimiter and pieces that
    are empty after stripping are dropped. The query and the remaining
    snippets are embedded with the module-level ``embedding_model`` and
    compared by cosine similarity.

    Args:
        query: Natural-language description of the code being searched for.
        code_input: Text holding one or more snippets separated by ``---``.
            ``None`` is treated like an empty string.

    Returns:
        A Markdown string with up to three best-scoring snippets, best first,
        each preceded by its similarity score and wrapped in a ``python``
        code fence. A plain message is returned instead when no snippet or
        no query text was supplied.
    """
    # ``or ""`` keeps a missing payload from crashing on ``.split``.
    pieces = (code_input or "").split("---")
    snippets = [piece.strip() for piece in pieces if piece.strip()]

    if not snippets:
        return "No code snippets detected (make sure to separate them with ---)."

    # Embedding a blank query would still yield scores, but they would be
    # meaningless; tell the user instead of ranking against nothing.
    if not query or not query.strip():
        return "Please enter a query to search for."

    query_emb = embedding_model.encode(query, convert_to_tensor=True)
    snippets_emb = embedding_model.encode(snippets, convert_to_tensor=True)

    # One query against all snippets: row 0 holds the score for each snippet.
    cos_scores = util.cos_sim(query_emb, snippets_emb)[0]

    # Never ask topk for more entries than there are snippets.
    top = torch.topk(cos_scores, k=min(3, len(snippets)))

    # Convert to plain Python numbers so list indexing and formatting do not
    # depend on tensor objects.
    results = [
        f"**Score**: {score:.4f}\n```python\n{snippets[idx]}\n```"
        for score, idx in zip(top.values.tolist(), top.indices.tolist())
    ]

    return "\n\n".join(results)
| |
|
| |
|
| | |
| | |
| | |
# Stylesheet handed to gr.Blocks: centres the element whose id is "container"
# and caps its width at 700px.
css = "\n".join(
    [
        "",
        "#container {",
        "    margin: 0 auto;",
        "    max-width: 700px;",
        "}",
        "",
    ]
)
| |
|
# Assemble the Gradio page: a heading, the two input boxes, a search button
# and a Markdown area that shows what retrieve_top_snippets returns.
# NOTE(review): the original indentation was lost, so the nesting of the
# components inside the Column/Row below is reconstructed - confirm.
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        "# Code Retrieval using ModernBERT\n"
        "Enter a natural language query and paste multiple Python code snippets, "
        "delimited by `---`. We'll return the top 3 matches."
    )

    # elem_id ties this column to the "#container" rule in the stylesheet.
    with gr.Column(elem_id="container"):
        with gr.Row():
            query_input = gr.Textbox(
                label="Natural Language Query",
                placeholder="What does your function do? e.g., 'Parse JSON from a string'",
            )

        code_snippets_input = gr.Textbox(
            label="Paste Python functions (delimited by ---)",
            lines=10,
            placeholder="Example:\n---\ndef parse_json(data):\n return json.loads(data)\n---\ndef add_numbers(a, b):\n return a + b\n---",
        )

        search_btn = gr.Button("Search", variant="primary")
        results_output = gr.Markdown(label="Top 3 Matches")

    # Clicking the button runs the retrieval on both textboxes and renders the
    # returned Markdown in the results area.
    search_btn.click(
        fn=retrieve_top_snippets,
        inputs=[query_input, code_snippets_input],
        outputs=results_output,
    )

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
| |
|