Spaces:
Running
Running
| import argparse | |
| import contextlib | |
| import io | |
| import logging | |
| import os | |
| import sys | |
| import pandas as pd | |
| import torch | |
| from transformers import AutoTokenizer, T5EncoderModel | |
| from transformers.utils import logging as transformers_logging | |
# -------------------------
# PROJECT ROOT DISCOVERY
# -------------------------
# Probe, in order: the script's own folder, its parent folder, and an
# "LLM-Prop" folder inside that parent. The first one holding
# llmprop_model.py is taken as the project root.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_CANDIDATES = [
    SCRIPT_DIR,
    os.path.dirname(SCRIPT_DIR),
    os.path.join(os.path.dirname(SCRIPT_DIR), "LLM-Prop"),
]
PROJECT_DIR = next(
    (
        root
        for root in PROJECT_CANDIDATES
        if os.path.exists(os.path.join(root, "llmprop_model.py"))
    ),
    None,
)
if PROJECT_DIR is None:
    raise FileNotFoundError(
        "Could not locate project root containing llmprop_model.py. "
        "Expected near the deployment folder."
    )

# Make llmprop_model importable from the discovered root. It is prepended so it
# wins over any same-named module found later on sys.path.
if PROJECT_DIR not in sys.path and os.path.isdir(PROJECT_DIR):
    sys.path.insert(0, PROJECT_DIR)

from llmprop_model import T5Predictor  # noqa: E402  (relies on the sys.path tweak above)
def z_denormalize(scaled_labels, labels_mean, labels_std):
    """Invert z-score normalization, mapping standardized values back to label scale.

    Computes ``scaled_labels * labels_std + labels_mean``. It works element-wise
    (with broadcasting) on tensors and also on plain Python numbers.
    """
    rescaled = scaled_labels * labels_std
    return rescaled + labels_mean
# -------------------------
# CONFIG
# -------------------------
# Every artifact location is resolved relative to the discovered project root.
MODEL_PATH = os.path.join(
    PROJECT_DIR, "checkpoints", "samples", "regression", "best_checkpoint_for_fepa.pt"
)
TOKENIZER_PATH = os.path.join(
    PROJECT_DIR, "tokenizers", "t5_tokenizer_trained_on_modified_part_of_C4_and_textedge"
)
# NOTE(review): TRAIN_DATA_PATH is not read anywhere in the visible code; the
# label statistics (TRAIN_LABEL_MEAN / TRAIN_LABEL_STD) are hard-coded further
# down instead. Confirm whether this path is still needed.
TRAIN_DATA_PATH = os.path.join(PROJECT_DIR, "data", "samples", "train_data.csv")
PROPERTY_NAME = "formation_energy_per_atom"
# All inference runs on the CPU.
DEVICE = torch.device("cpu")
# Silence HF/Transformers startup logs for cleaner terminal output.
# NOTE: transformers is already imported at the top of this file. Any library
# that reads these environment variables only at import time will not see them
# in this process, though they are still inherited by child processes. The
# explicit API calls below take effect immediately and do the real work here.
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
transformers_logging.set_verbosity_error()
# Turn progress bars off programmatically, since the env var above may be too late.
transformers_logging.disable_progress_bar()
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
# -------------------------
# PATH CHECKS
# -------------------------
def _require_path(path, error_message):
    """Raise ``FileNotFoundError(error_message)`` unless *path* exists on disk."""
    if not os.path.exists(path):
        raise FileNotFoundError(error_message)


# Fail fast, before any model download or weight loading, if an artifact is missing.
_require_path(MODEL_PATH, f"Checkpoint not found: {MODEL_PATH}")
_require_path(TOKENIZER_PATH, f"Tokenizer path not found: {TOKENIZER_PATH}")

# Statistics that predict_fepa passes to z_denormalize to turn the model's
# normalized output back into formation_energy_per_atom.
# NOTE(review): presumably the mean/std of the training labels (cf. TRAIN_DATA_PATH);
# confirm against the training run whenever the checkpoint changes.
TRAIN_LABEL_MEAN = torch.tensor(-0.364792)
TRAIN_LABEL_STD = torch.tensor(0.273129)
| def _quiet_call(fn, *args, **kwargs): | |
| with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()): | |
| return fn(*args, **kwargs) | |
# -------------------------
# LOAD TOKENIZER
# -------------------------
tokenizer = _quiet_call(AutoTokenizer.from_pretrained, TOKENIZER_PATH)
# -------------------------
# LOAD MODEL
# -------------------------
base_model = _quiet_call(T5EncoderModel.from_pretrained, "google/t5-v1_1-small")
# Width of the encoder's hidden states, handed to T5Predictor (presumably the
# input size of its prediction head). It is read from the model config (512 for
# t5-v1_1-small) rather than hard-coded, so it cannot silently disagree with the
# encoder if the base model is ever changed.
base_model_output_size = base_model.config.d_model
# Match embedding matrix size to the tokenizer used during training.
base_model.resize_token_embeddings(len(tokenizer))
model = T5Predictor(
    base_model,
    base_model_output_size,
    drop_rate=0.1,
    pooling="mean",
)
# -------------------------
# LOAD WEIGHTS
# -------------------------
state_dict = _quiet_call(torch.load, MODEL_PATH, map_location=DEVICE)
# Some checkpoints were trained with an extra tokenizer token; align embedding size to checkpoint.
checkpoint_vocab_size = state_dict["model.shared.weight"].shape[0]
if model.model.shared.weight.shape[0] != checkpoint_vocab_size:
    model.model.resize_token_embeddings(checkpoint_vocab_size)
# strict=False tolerates key mismatches, but a mismatch leaves the affected
# parameters at their pre-checkpoint values. Predictions would then be wrong with
# no error raised. Surface any mismatch instead of discarding the load report.
_incompatible_keys = model.load_state_dict(state_dict, strict=False)
if _incompatible_keys.missing_keys or _incompatible_keys.unexpected_keys:
    logging.getLogger(__name__).warning(
        "Checkpoint/model key mismatch while loading %s: missing=%s unexpected=%s",
        MODEL_PATH,
        _incompatible_keys.missing_keys,
        _incompatible_keys.unexpected_keys,
    )
model.to(DEVICE)
model.eval()
# -------------------------
# PREDICT FUNCTION
# -------------------------
def predict_fepa(text, max_length=256):
    """Predict formation_energy_per_atom for one or more text descriptions.

    Args:
        text: A single description string, or a list of strings to score as one
            padded batch.
        max_length: Maximum number of tokens per input; longer inputs are truncated.

    Returns:
        A ``float`` when exactly one value is produced (a single string or a
        one-element list, which matches the original behavior). Otherwise, a list
        with one entry per input, in input order. This assumes the model emits one
        value per input; TODO confirm against T5Predictor. Values are de-normalized
        with TRAIN_LABEL_MEAN / TRAIN_LABEL_STD.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_length,
    )
    input_ids = inputs["input_ids"].to(DEVICE)
    attention_mask = inputs["attention_mask"].to(DEVICE)
    with torch.no_grad():
        # The model returns a pair; only the second element (the normalized
        # prediction) is used here.
        _, prediction_norm = model(input_ids, attention_mask)
    prediction_fepa = z_denormalize(
        prediction_norm.squeeze(),
        TRAIN_LABEL_MEAN,
        TRAIN_LABEL_STD,
    )
    # .item() only accepts a single-element tensor, so batches of more than one
    # text used to crash here. Keep returning a plain float for the single-value
    # case and return a list for batches.
    if prediction_fepa.numel() == 1:
        return prediction_fepa.item()
    return prediction_fepa.tolist()
# -------------------------
# TEST
# -------------------------
if __name__ == "__main__":
    # Sample structure description used when --text is not supplied.
    # NOTE(review): the sample begins "A simple cubic crystalLiAl(...)" (missing
    # space, and "cubic" conflicts with the stated triclinic group) and contains
    # "edgeedge". It is kept verbatim so the default prediction does not change;
    # confirm whether these artifacts are intentional.
    sample_text = "A simple cubic crystalLiAl(MoO₄)₂ crystallizes in the triclinic P̅1 space group. Li¹⁺ is bonded in a 5-coordinate geometry to five O²⁻ atoms. There are a spread of Li–O bond distances ranging from 1.98–2.25 Å. There are two inequivalent Mo⁶⁺ sites. In the first Mo⁶⁺ site, Mo⁶⁺ is bonded in a 4-coordinate geometry to five O²⁻ atoms. There are a spread of Mo–O bond distances ranging from 1.74–2.46 Å. In the second Mo⁶⁺ site, Mo⁶⁺ is bonded to four O²⁻ atoms to form MoO₄ tetrahedra that share corners with three equivalent AlO₆ octahedra. The corner-sharing octahedral tilt angles range from 15–44°. There are a spread of Mo–O bond distances ranging from 1.77–1.82 Å. Al³⁺ is bonded to six O²⁻ atoms to form AlO₆ octahedra that share corners with three equivalent MoO₄ tetrahedra and an edgeedge with one AlO₆ octahedra. There are a spread of Al–O bond distances ranging from 1.88–1.95 Å. There are eight inequivalent O²⁻ sites. In the first O²⁻ site, O²⁻ is bonded in a distorted trigonal planar geometry to one Li¹⁺, one Mo⁶⁺, and one Al³⁺ atom. In the second O²⁻ site, O²⁻ is bonded in a distorted trigonal planar geometry to one Mo⁶⁺ and two equivalent Al³⁺ atoms. In the third O²⁻ site, O²⁻ is bonded in a bent 150 degrees geometry to one Li¹⁺ and one Mo⁶⁺ atom. In the fourth O²⁻ site, O²⁻ is bonded in a linear geometry to one Li¹⁺ and one Mo⁶⁺ atom. In the fifth O²⁻ site, O²⁻ is bonded in a linear geometry to one Mo⁶⁺ and one Al³⁺ atom. In the sixth O²⁻ site, O²⁻ is bonded in a bent 150 degrees geometry to one Li¹⁺ and one Mo⁶⁺ atom. In the seventh O²⁻ site, O²⁻ is bonded in a 4-coordinate geometry to one Li¹⁺, two equivalent Mo⁶⁺, and one Al³⁺ atom. In the eighth O²⁻ site, O²⁻ is bonded in a bent 150 degrees geometry to one Mo⁶⁺ and one Al³⁺ atom. with atoms arranged periodically and stable at room temperature."

    cli = argparse.ArgumentParser(description="Predict formation_energy_per_atom from text")
    cli.add_argument("--max_length", type=int, default=256, help="Tokenizer max length")
    cli.add_argument("--text", type=str, default=sample_text, help="Input text to predict FEPA")
    cli_args = cli.parse_args()

    predicted = predict_fepa(cli_args.text, max_length=cli_args.max_length)
    print(f"Predicted formation_energy_per_atom: {predicted:.6f}")