| | import json |
| | import os |
| | import pickle |
| | import random |
| | import time |
| | from collections import Counter |
| | from datetime import datetime |
| | from glob import glob |
| |
|
| | import gdown |
| | import matplotlib.pyplot as plt |
| | import numpy as np |
| | import pandas as pd |
| | import seaborn as sns |
| | import streamlit as st |
| | from PIL import Image |
| |
|
| | import SessionState |
| | from download_utils import * |
| | from image_utils import * |
| |
|
# Re-seed both RNGs from OS entropy each time this script executes.
# * Since Python 3.11, random.seed() only accepts None/int/float/str/bytes/
#   bytearray; the previous random.seed(datetime.now()) raised TypeError.
# * Seeding numpy with int(time.time()) gave every session that sampled its
#   queries (np.random.choice in resmaple_queries) within the same wall-clock
#   second the very same set of images.
random.seed()
np.random.seed()
| |
|
# Number of query images shown in one experiment session.
NUMBER_OF_TRIALS = 20

# Both are reassigned by main() once an explanation method has been chosen:
# the key prefix used to read classifier_predictions, and the loader that
# produces the explanation image for a query (None = no explanation shown).
CLASSIFIER_TAG, selected_xai_tool = "", None

# NOTE(review): not referenced elsewhere in this file; kept in case other code uses it.
explaination_functions = [load_chm_nns, load_knn_nns]
| |
|
| | |
# Lookup tables; each is reassigned further down once its data file is loaded.
folder_to_name, class_descriptions, classifier_predictions = {}, {}, {}

# Sub-folder of root_visualization_dir that query images are sampled from.
selected_dataset = "Final"

# Google Drive sources and the local names/paths handed to download_files().
root_visualization_dir = "./visualizations/"
viz_url = "https://drive.google.com/uc?id=1LpmOc_nFBzApYWAokO2J-s9RRXsk3pBN"
viz_archivefile = "Final.zip"

demonstration_url = "https://drive.google.com/uc?id=1C92llG5VrlABrsIEvxfNlSDc_gIeLlls"
demonst_zipfile = "demonstrations.zip"

picklefile_url = "https://drive.google.com/uc?id=1Yx4abA4VLZGO5JkzhXVGdy6mbPltMd68"
prediction_root = "./predictions/"
prediction_pickle = prediction_root + "predictions.pickle"
| |
|
| |
|
| | |
# Fetch the visualization archive, the class-demonstration archive and the
# predictions pickle before anything below reads them. The arguments are
# positional, so their order must match download_utils.download_files
# (imported via *). NOTE(review): presumably it skips files already on disk,
# since this runs on every script execution -- TODO confirm.
download_files(
    root_visualization_dir,
    viz_url,
    viz_archivefile,
    demonstration_url,
    demonst_zipfile,
    picklefile_url,
    prediction_root,
    prediction_pickle,
)

# Module-level page selector. render_menu() assigns its own local `app_mode`
# and never writes this global.
app_mode = ""
| |
|
| | |
# Predicted class id (wnid) -> display name shown as the model's prediction.
with open("imagenet-labels.json", "rb") as f:
    folder_to_name = json.load(f)

# Class id -> textual class definition, from "<wnid>\t<definition>" lines.
with open("gloss.txt", "r", encoding="utf-8") as f:
    description_file = f.readlines()

# split("\t", 1) keeps definitions that themselves contain tabs intact; the
# rstrip drops the line terminator that used to end up inside the rendered
# inline-code span; lines without a tab (e.g. a trailing blank line) are
# skipped instead of raising IndexError.
class_descriptions = dict(
    line.rstrip("\r\n").split("\t", 1) for line in description_file if "\t" in line
)

# SECURITY(review): pickle.load can execute arbitrary code embedded in the
# file. This is presumably the file download_files() fetched from
# picklefile_url; only load it if that source is trusted (or store the
# predictions as JSON instead).
with open(prediction_pickle, "rb") as f:
    classifier_predictions = pickle.load(f)
| |
|
| | |
# Defaults for the per-session state object. The rest of the file relies on
# these fields carrying over between reruns (e.g. `first_run` is flipped to -1
# once queries are sampled). NOTE(review): presumably SessionState.get only
# applies these values when a session is first created -- TODO confirm.
_SESSION_DEFAULTS = {
    "page": 1,  # 1-based number of the query currently shown
    "first_run": 1,  # 1 => resmaple_queries() draws a new query set
    "user_feedback": {},  # query file name -> "-" / "Correct" / "Wrong"
    "queries": [],  # sampled query image paths
    "is_classifier_correct": {},  # query file name -> classifier outcome flag
    "XAI_tool": "Unselected",  # explanation method chosen in main()
}
session_state = SessionState.get(**_SESSION_DEFAULTS)
| | |
| |
|
| |
|
def resmaple_queries():
    """Sample a new, shuffled set of query images when the session asks for one.

    Does nothing unless ``session_state.first_run == 1`` (a fresh session, or
    after "Resample" in render_menu()). Otherwise it draws images without
    replacement from the ``Both_correct`` and ``Both_wrong`` folders of the
    selected dataset -- NUMBER_OF_TRIALS in total, split as evenly as
    possible -- shuffles them into ``session_state.queries`` and clears the
    answers recorded for the previous query set.

    Raises:
        ValueError: if either folder holds fewer images than it must supply.
    """
    if session_state.first_run != 1:
        return

    dataset_dir = os.path.join(root_visualization_dir, selected_dataset)
    both_correct = glob(os.path.join(dataset_dir, "Both_correct", "*.JPEG"))
    both_wrong = glob(os.path.join(dataset_dir, "Both_wrong", "*.JPEG"))

    # The "wrong" half takes the remainder so that an odd NUMBER_OF_TRIALS
    # still yields NUMBER_OF_TRIALS queries. Sampling NUMBER_OF_TRIALS // 2
    # from both folders came up one short, and render_menu()'s last page then
    # indexed past the end of session_state.queries.
    n_correct = NUMBER_OF_TRIALS // 2
    n_wrong = NUMBER_OF_TRIALS - n_correct
    if len(both_correct) < n_correct or len(both_wrong) < n_wrong:
        # Same exception type np.random.choice would raise, but naming the
        # folders makes a missing or incomplete download easy to spot.
        raise ValueError(
            f"Need {n_correct} images in 'Both_correct' and {n_wrong} in "
            f"'Both_wrong' under {dataset_dir!r}; found {len(both_correct)} "
            f"and {len(both_wrong)}."
        )

    correct_samples = list(
        np.random.choice(a=both_correct, size=n_correct, replace=False)
    )
    wrong_samples = list(np.random.choice(a=both_wrong, size=n_wrong, replace=False))

    all_images = correct_samples + wrong_samples
    random.shuffle(all_images)
    session_state.queries = all_images
    session_state.first_run = -1

    # Answers given for the previous query set no longer apply.
    session_state.user_feedback = {}
    session_state.is_classifier_correct = {}
| |
|
| |
|
def render_experiment(query):
    """Render one trial: the query image, the model's prediction and an answer radio.

    Args:
        query: 0-based index into ``session_state.queries``.

    Side effects: stores the participant's answer ("-", "Correct" or "Wrong")
    in ``session_state.user_feedback`` and the classifier's "<TAG>-Output"
    value in ``session_state.is_classifier_correct``, both keyed by the query
    image's file name.
    """
    query_path = session_state.queries[query]
    query_id = os.path.basename(query_path)
    record = classifier_predictions[query_id]

    predicted_wnid = record[f"{CLASSIFIER_TAG}-predictions"]
    prediction_confidence = record[f"{CLASSIFIER_TAG}-confidence"]
    prediction_label = folder_to_name[predicted_wnid]
    class_def = class_descriptions[predicted_wnid]

    # The outcome flag lives under the upper-cased tag, unlike the keys above.
    session_state.is_classifier_correct[query_id] = record[
        f"{CLASSIFIER_TAG.upper()}-Output"
    ]

    # Re-select an earlier answer when the participant pages back here.
    # Index 0 ("-") is the placeholder hidden by the radiogroup CSS in main().
    answers = ("-", "Correct", "Wrong")
    default_value = {"Correct": 1, "Wrong": 2}.get(
        session_state.user_feedback.get(query_id), 0
    )

    image_col, info_col = st.columns(2)
    with image_col:
        st.image(load_query(query_path), caption=f"Query ID: {query_id}")
    with info_col:
        with st.expander("Show Class Description"):
            st.write(f"**Name**: {prediction_label}")
            st.write("**Class Definition**:")
            st.markdown(f"`{class_def}`")
            st.image(
                Image.open(f"demonstrations/{predicted_wnid}.jpeg"),
                caption="Class Explanation",
                use_column_width=True,
            )

        session_state.user_feedback[query_id] = st.radio(
            "What do you think about model's prediction?",
            answers,
            key=query_id,
            index=default_value,
        )
        st.write(f"**Model Prediction**: {prediction_label}")
        st.write(f"**Model Confidence**: {prediction_confidence}")

    # Explanation image produced by the method chosen in main(), if any.
    if selected_xai_tool is not None:
        st.image(
            selected_xai_tool(query_path),
            caption="Explaination",
            use_column_width=True,
        )

    if st.button("Debug: Show Everything"):
        st.image(Image.open(query_path))
| |
|
| |
|
def render_results():
    """Render the participant's score and its breakdowns for the chosen classifier.

    For every answered query ``q`` the classifier's outcome is "Correct" when
    ``session_state.is_classifier_correct[q]`` is truthy and "Wrong" otherwise.
    The participant's answer counts as correct only if it equals that outcome,
    so an unanswered query ("-") is never correct. Shows the overall score, a
    per-query summary table, accuracy grouped by classifier outcome and a
    confusion matrix, each with a CSV download button.
    """
    feedback = session_state.user_feedback
    if not feedback:
        # Nothing answered yet (e.g. results opened before the experiment).
        st.write("No answers recorded yet - start the experiment first.")
        return

    experiment_summary = []
    user_correct_guess = 0
    for q, user_answer in feedback.items():
        category = "Correct" if session_state.is_classifier_correct[q] else "Wrong"
        # One definition of "the user was right" for the headline AND the
        # tables. The headline used to compare the flag with
        # `answer == "Correct"`, which scored an unanswered "-" as a "Wrong"
        # guess while the tables below scored it as incorrect.
        is_user_correct = user_answer == category
        if is_user_correct:
            user_correct_guess += 1

        experiment_summary.append(
            [
                q,
                classifier_predictions[q]["real-gts"],
                folder_to_name[
                    classifier_predictions[q][f"{CLASSIFIER_TAG}-predictions"]
                ],
                category,
                user_answer,
                is_user_correct,
            ]
        )

    st.write(
        f"User performance on {CLASSIFIER_TAG}: {user_correct_guess} out of "
        f"{len(feedback)} Correct"
    )
    st.markdown("## User Performance Breakdown")

    experiment_summary_df = pd.DataFrame.from_records(
        experiment_summary,
        columns=[
            "Query",
            "GT Labels",
            f"{CLASSIFIER_TAG} Prediction",
            "Category",
            "User Prediction",
            "Is User Prediction Correct",
        ],
    )
    st.write("Summary", experiment_summary_df)
    csv = convert_df(experiment_summary_df)
    st.download_button(
        "Press to Download", csv, "summary.csv", "text/csv", key="download-records"
    )

    # Per classifier outcome: number of queries, correct user guesses, accuracy.
    user_pf_by_model_pred = experiment_summary_df.groupby("Category")[
        "Is User Prediction Correct"
    ].agg(["count", "sum", "mean"])
    user_pf_by_model_pred.columns = [
        "Count",
        "Correct User Guess",
        "Mean User Performance",
    ]
    user_pf_by_model_pred.index.name = "Model Prediction"
    st.write("User performance break down by Model prediction:", user_pf_by_model_pred)
    csv = convert_df(user_pf_by_model_pred)
    st.download_button(
        "Press to Download",
        csv,
        "user-performance-by-model-prediction.csv",
        "text/csv",
        key="download-performance-by-model-prediction",
    )

    # Rows: actual classifier outcome; columns: the participant's answer.
    confusion_matrix = pd.crosstab(
        experiment_summary_df["Category"],
        experiment_summary_df["User Prediction"],
        rownames=["Actual"],
        colnames=["Predicted"],
    )
    st.write("Confusion Matrix", confusion_matrix)
    csv = convert_df(confusion_matrix)
    st.download_button(
        "Press to Download",
        csv,
        "confusion-matrix.csv",
        "text/csv",
        key="download-confusiion-matrix",
    )
| |
|
| |
|
def render_menu():
    """Render the page selector and whichever page it selects.

    "Start Experiment" shows one query at a time with Previous/Next/Resample
    controls; the 1-based current page is kept in ``session_state.page``.
    "See the Results" shows render_results(). On the last query a "View"
    button shows the results in place of the experiment view.
    """
    readme_text = st.markdown(
        """
# Instructions
```
When testing this study, you should first see the class definition, then hide the expander and see the query.
```
"""
    )

    app_mode = st.selectbox(
        "Choose the page to show:",
        ["Experiment Instruction", "Start Experiment", "See the Results"],
    )

    if app_mode == "Experiment Instruction":
        st.success("To continue select an option in the dropdown menu.")
    elif app_mode == "Start Experiment":
        readme_text.empty()

        page_id = session_state.page
        # The second column is only a spacer between "Previous" and "Next".
        col1, _spacer, col2, col3 = st.columns(4)
        if col1.button("Previous Image"):
            page_id = max(page_id - 1, 1)
        if col2.button("Next Image"):
            page_id = min(page_id + 1, NUMBER_OF_TRIALS)

        if page_id == NUMBER_OF_TRIALS:
            st.success(
                'You have reached the last image. Please go to the "Results" page to see your performance.'
            )
            if st.button("View"):
                # This used to rebind the local app_mode only after the
                # if/elif chain had already been decided, so the button had
                # no effect. Now the results replace the experiment view below.
                app_mode = "See the Results"

        if col3.button("Resample"):
            st.write("Restarting ...")
            page_id = 1
            session_state.first_run = 1
            resmaple_queries()

        session_state.page = page_id
        if app_mode == "Start Experiment":
            st.write(f"Render Experiment: {session_state.page}")
            render_experiment(session_state.page - 1)

    # A separate `if` (not `elif`) so the "View" button above can reach it.
    if app_mode == "See the Results":
        readme_text.empty()
        st.write("Results Summary")
        render_results()
| |
|
| |
|
def main():
    """Build the whole page for one run of the script.

    Asks the participant once per session which explanation method to
    evaluate, and sets ``CLASSIFIER_TAG`` and ``selected_xai_tool`` for that
    method. Then makes sure a query set has been sampled and renders the
    page menu. Nothing past the method picker is shown until a method has
    been chosen.
    """
    global selected_xai_tool
    global CLASSIFIER_TAG

    st.set_page_config(layout="wide")
    st.title("Visual Correspondence Human Study - ImageNet")

    # Method name -> (explanation loader or None, key prefix used to read
    # classifier_predictions). Insertion order is the order shown in the radio.
    xai_methods = {
        "NOXAI": (None, "knn"),
        "KNN": (load_knn_nns, "knn"),
        "EMD-Corr Nearest Neighbors": (load_emd_nns, "EMD"),
        "EMD-Corr Correspondence": (load_emd_corrs, "EMD"),
        "CHM-Corr Nearest Neighbors": (load_chm_nns, "CHM"),
        "CHM-Corr Correspondence": (load_chm_corrs, "CHM"),
    }
    options = ["Unselected", *xai_methods]

    # Hide the first option of every radio group: "Unselected" below and the
    # "-" placeholder of the per-query answer radio in render_experiment().
    st.markdown(
        """ <style>
div[role="radiogroup"] > :first-child{
display: none !important;
}
</style>
""",
        unsafe_allow_html=True,
    )

    # The question is only rendered while no method has been chosen, so the
    # first choice sticks for the rest of the session.
    if session_state.XAI_tool == "Unselected":
        session_state.XAI_tool = st.radio(
            "What explanation tool do you want to evaluate?",
            options,
            key="which_xai",
            index=0,
        )

    if session_state.XAI_tool == "Unselected":
        # Without a method CLASSIFIER_TAG stays "", and opening the experiment
        # would fail on the "-predictions" key lookup in render_experiment().
        return

    st.markdown(f"## SELECTED METHOD ``{session_state.XAI_tool}``")
    selected_xai_tool, CLASSIFIER_TAG = xai_methods[session_state.XAI_tool]

    resmaple_queries()
    render_menu()
| |
|
| |
|
# Build the page when this file is executed as a script.
if __name__ == "__main__":
    main()
| |
|