| --- |
| license: mit |
| language: |
| - ru |
| tags: |
| - Prompt |
| - Prompt_Classification |
| - Classification |
| - LinearModel |
| - MLP |
| - class |
| - Prompt Classes |
| - Classificator |
| - Prompt Classification |
| - AI |
| - ML |
| - Class |
| - Classify |
| - Text |
| - Context |
| --- |
| |
| # 🔁 SimplePromptClassifier — классификатор промптов (русский) |
|
|
|  |
|
|
|
|
| **Кратко:** модель классифицирует входные промпты/вопросы на три действия: |
| - **0 — Поиск в локальной базе знаний (RAG)**: сначала ищем релевантные документы в локальном индексе и формируем контекст для генерации. |
| - **1 — Поиск в сети**: триггер запуска обхода внешних поисковых систем/скрейпинга. |
| - **2 — Прямой запрос**: сразу посылаем промпт в генеративную модель (например, LLM) для синтеза ответа. |
|
|
| --- |
|
|
| ## Где используется |
| Подходит для систем, где нужно автоматически решать стратегию обработки пользовательского промпта: |
| - чат-боты со связкой Retrieval-Augmented Generation (RAG), |
| - голосовые ассистенты, |
| - интерфейсы поддержки, где часть запросов решается поиском, часть — генерацией. |
|
|
| --- |
|
|
| ## Файлы в репозитории |
| - `pytorch_model.bin` — веса модели (state_dict). |
| - `config.json` — конфигурация (input_dim, num_classes, p_dropout, classes). |
| - `modeling_simple_classifier.py` — определение архитектуры. |
| - `vectorizer.pkl` — sklearn-векторизатор (TF-IDF/Count). |
| - `svd.pkl` — TruncatedSVD (опционально). |
| - `label_encoder.pkl` — sklearn.LabelEncoder (для декодирования метки). |
| - `README.md` — эта карточка. |
|
|
| --- |
|
|
| ## Пример загрузки и инференса (без AutoModel) |
|
|
| ```python |
| # Пример: загрузка напрямую из репозитория HF (не требует локальной копии) |
| from huggingface_hub import hf_hub_download |
| import json, pickle, torch |
| import numpy as np |
| from types import SimpleNamespace |
| |
| REPO = "Neweret/SimplePromptClassifier-85k" |
| |
| config_path = hf_hub_download(REPO, "config.json") |
| weights_path = hf_hub_download(REPO, "pytorch_model.bin") |
| vec_path = hf_hub_download(REPO, "vectorizer.pkl") |
| svd_path = None |
| try: |
| svd_path = hf_hub_download(REPO, "svd.pkl") |
| except Exception: |
| svd_path = None |
| le_path = hf_hub_download(REPO, "label_encoder.pkl") |
| |
| cfg = SimpleNamespace(**json.load(open(config_path, "r", encoding="utf-8"))) |
| |
| # --- Динамическая модель --- |
| class SimpleClassifier(torch.nn.Module): |
| def __init__(self, input_dim, num_classes, p_dropout=0.3): |
| super().__init__() |
| self.linear1 = torch.nn.Linear(input_dim, 256) |
| self.ln1 = torch.nn.LayerNorm(256) |
| self.dropout = torch.nn.Dropout(p_dropout) |
| self.linear2 = torch.nn.Linear(256, 128) |
| self.ln2 = torch.nn.LayerNorm(128) |
| self.linear_out = torch.nn.Linear(128, num_classes) |
| def forward(self, x): |
| x = torch.nn.functional.gelu(self.ln1(self.linear1(x))) |
| x = self.dropout(x) |
| x = torch.nn.functional.gelu(self.ln2(self.linear2(x))) |
| x = self.dropout(x) |
| return self.linear_out(x) |
| |
| model = SimpleClassifier(cfg.input_dim, cfg.num_classes, cfg.p_dropout) |
| state = torch.load(weights_path, map_location="cpu") |
| model.load_state_dict(state) |
| model.eval() |
| |
| # препроцессинг |
| vectorizer = pickle.load(open(vec_path, "rb")) |
| svd = pickle.load(open(svd_path, "rb")) if svd_path else None |
| le = pickle.load(open(le_path, "rb")) |
| |
| def preprocess(text): |
| X = vectorizer.transform([text]) |
| if svd is not None: |
| X = svd.transform(X) |
| return X.astype(np.float32) |
| |
| def predict(text): |
| x = preprocess(text) |
| xb = torch.from_numpy(x).float() |
| with torch.inference_mode(): |
| logits = model(xb) |
| pred = int(torch.argmax(logits, dim=1).cpu().numpy()[0]) |
| return pred, le.inverse_transform([pred])[0] |
| |
| # пример |
| print(predict("Как мне найти документацию по нашей компании?")) |