#############################################################################################################################
# Filename   : app.py
# Description: A Streamlit application to showcase the importance of Responsible AI in LLMs.
# Author     : Georgios Ioannou
#
# Copyright © 2024 by Georgios Ioannou
#############################################################################################################################
# Import libraries.
import os
import requests
import streamlit as st
import streamlit.components.v1 as components

from dataclasses import dataclass
from dotenv import find_dotenv, load_dotenv
from huggingface_hub import InferenceClient
from langchain.callbacks import get_openai_callback
from langchain.chains import ConversationChain
from langchain.llms import OpenAI
from policies import complex_policy, simple_policy
from typing import Literal

#############################################################################################################################
# Load environment variable(s).
load_dotenv(find_dotenv())
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY")
#############################################################################################################################
@dataclass
class Message:
    """Class for keeping track of a chat message."""

    origin: Literal["human", "ai"]
    message: str
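# Illustrative use of the dataclass above (not part of the app flow):
#   Message("human", "Hello") -> Message(origin="human", message="Hello")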
#############################################################################################################################
# Initialize Hugging Face clients.
def initialize_hf_clients():
    client = InferenceClient(api_key=HUGGINGFACE_API_KEY)
    gpt2_api_url = "https://api-inference.huggingface.co/models/openai-community/gpt2"
    headers = {"Authorization": f"Bearer {HUGGINGFACE_API_KEY}"}
    return client, gpt2_api_url, headers
#############################################################################################################################
# Hugging Face model inference functions.
def qwen_inference(prompt):
    client, _, _ = initialize_hf_clients()
    messages = [{"role": "user", "content": prompt}]
    try:
        response = client.chat.completions.create(
            model="Qwen/Qwen2.5-1.5B-Instruct", messages=messages, max_tokens=500
        )
        return response.choices[0].message.content
    except Exception as e:
        return f"Error with Qwen inference: {str(e)}"


def gpt2_inference(prompt):
    _, api_url, headers = initialize_hf_clients()
    try:
        response = requests.post(api_url, headers=headers, json={"inputs": prompt})
        return response.json()[0]["generated_text"]
    except Exception as e:
        return f"Error with GPT-2 inference: {str(e)}"
#############################################################################################################################
# Initialize session state.
def initialize_session_state():
    if "history" not in st.session_state:
        st.session_state.history = []
    if "token_count" not in st.session_state:
        st.session_state.token_count = 0
    if "policy" not in st.session_state:
        st.session_state.policy = "No Policy"
    if "selected_ai_icon" not in st.session_state:
| st.session_state.selected_ai_icon = "a1.png" | |
| if "selected_user_icon" not in st.session_state: | |
| st.session_state.selected_user_icon = "man.png" | |
| ############################################################################################################################# | |
| # Function to apply local CSS. | |
| def local_css(file_name): | |
| with open(file_name) as f: | |
| st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True) | |
#############################################################################################################################
def on_click_callback():
    """Handle a chat submission: run the input safety check, query the selected model,
    run the output safety check, and append the results to the chat history."""
    with get_openai_callback() as cb:
        human_prompt = st.session_state.human_prompt

        # Policy + User Query Text.
        role = "User"
        task = f"""
        Task: Check if there is unsafe content in
        '{role}' messages in conversations
        according to our safety policy with the below categories.
        """
        output_format = f"""
        Provide your safety assessment for '{role}' in the above conversation:
        - First line must read 'safe' or 'unsafe'.
        - If unsafe, a second line must include a comma-separated list of violated categories.
        """
        query = human_prompt
        conversation = f"""
        <BEGIN CONVERSATION>
        User: {query}
        <END CONVERSATION>
        """
        if st.session_state.policy == "Simple Policy":
            prompt = f"{task}\n{simple_policy}\n{conversation}\n{output_format}"
        elif st.session_state.policy == "Complex Policy":
            prompt = f"{task}\n{complex_policy}\n{conversation}\n{output_format}"
        else:
            prompt = human_prompt

        # Safety check 1 - Input check.
        if (
            "gpt" in st.session_state.model.lower()
            and "gpt2" not in st.session_state.model.lower()
        ):
            llm_response_safety_check_1 = st.session_state.conversation.run(prompt)
            st.session_state.token_count += cb.total_tokens
        elif "qwen" in st.session_state.model.lower():
            llm_response_safety_check_1 = qwen_inference(prompt)
            st.session_state.token_count += cb.total_tokens
        else:  # gpt2.
            llm_response_safety_check_1 = gpt2_inference(prompt)
            st.session_state.token_count += cb.total_tokens

        st.session_state.history.append(Message("human", human_prompt))
        if "unsafe" in llm_response_safety_check_1.lower():
            st.session_state.history.append(Message("ai", llm_response_safety_check_1))
            return
        # Get model response.
        if (
            "gpt" in st.session_state.model.lower()
            and "gpt2" not in st.session_state.model.lower()
        ):
            conversation_chain = ConversationChain(
                llm=OpenAI(
                    temperature=0.2,
                    openai_api_key=OPENAI_API_KEY,
                    model_name=st.session_state.model,
                )
            )
            llm_response = conversation_chain.run(human_prompt)
            st.session_state.token_count += cb.total_tokens
        elif "qwen" in st.session_state.model.lower():
            llm_response = qwen_inference(human_prompt)
            st.session_state.token_count += cb.total_tokens
        else:  # gpt2.
            llm_response = gpt2_inference(human_prompt)
            st.session_state.token_count += cb.total_tokens
        # Safety check 2 - Output check.
        query = llm_response
        conversation = f"""
        <BEGIN CONVERSATION>
        User: {query}
        <END CONVERSATION>
        """
        if st.session_state.policy == "Simple Policy":
            prompt = f"{task}\n{simple_policy}\n{conversation}\n{output_format}"
        elif st.session_state.policy == "Complex Policy":
            prompt = f"{task}\n{complex_policy}\n{conversation}\n{output_format}"
        else:
            prompt = llm_response

        if (
            "gpt" in st.session_state.model.lower()
            and "gpt2" not in st.session_state.model.lower()
        ):
            llm_response_safety_check_2 = st.session_state.conversation.run(prompt)
            st.session_state.token_count += cb.total_tokens
        elif "qwen" in st.session_state.model.lower():
            llm_response_safety_check_2 = qwen_inference(prompt)
            st.session_state.token_count += cb.total_tokens
        else:  # gpt2.
            llm_response_safety_check_2 = gpt2_inference(prompt)
            st.session_state.token_count += cb.total_tokens

        if "unsafe" in llm_response_safety_check_2.lower():
            st.session_state.history.append(
                Message(
                    "ai",
                    "THIS IS FROM THE AUTHOR OF THE CODE: THE LLM WANTED TO RESPOND UNSAFELY!",
                )
            )
        else:
            st.session_state.history.append(Message("ai", llm_response))
#############################################################################################################################
def main():
    """Build the Streamlit UI: sidebar controls, chat history, and the prompt form."""
    initialize_session_state()

    # Page title and favicon.
    st.set_page_config(page_title="Responsible AI", page_icon="⚖️")

    # Load CSS.
    local_css("./static/styles/styles.css")

    # Title.
    title = """<h1 align="center" style="font-family: monospace; font-size: 2.1rem; margin-top: -4rem">
               Responsible AI</h1>"""
    st.markdown(title, unsafe_allow_html=True)

    # Subtitle 1.
    subtitle1 = """<h3 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: -2rem">
                   Showcase the importance of Responsible AI in LLMs Using Policies</h3>"""
    st.markdown(subtitle1, unsafe_allow_html=True)

    # Subtitle 2.
    subtitle2 = """<h2 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: 0rem">
                   CUNY Tech Prep Tutorial 6</h2>"""
    st.markdown(subtitle2, unsafe_allow_html=True)

    # Image.
    image = "./static/ctp.png"
    left_co, cent_co, last_co = st.columns(3)
    with cent_co:
        st.image(image=image)
    # Sidebar dropdown menu for Models.
    models = [
        "gpt-3.5-turbo",
        "gpt-3.5-turbo-instruct",
        "gpt-4-turbo",
        "gpt-4",
        "Qwen2.5-1.5B-Instruct",
        "gpt2",
    ]
    selected_model = st.sidebar.selectbox("Select Model:", models)
    st.sidebar.markdown(
        f"<span style='color: white;'>Current Model: {selected_model}</span>",
        unsafe_allow_html=True,
    )
    st.session_state.model = selected_model
    if "gpt" in selected_model.lower() and "gpt2" not in selected_model.lower():
        st.session_state.conversation = ConversationChain(
            llm=OpenAI(
                temperature=0.2,
                openai_api_key=OPENAI_API_KEY,
                model_name=st.session_state.model,
            ),
        )
    # Sidebar dropdown menu for Policies.
    policies = ["No Policy", "Complex Policy", "Simple Policy"]
    selected_policy = st.sidebar.selectbox("Select Policy:", policies)
    st.sidebar.markdown(
        f"<span style='color: white;'>Current Policy: {selected_policy}</span>",
        unsafe_allow_html=True,
    )
    st.session_state.policy = selected_policy

    # Sidebar dropdown menu for AI Icons.
    ai_icons = ["AI 1", "AI 2"]
    selected_ai_icon = st.sidebar.selectbox("AI Icon:", ai_icons)
    st.sidebar.markdown(
        f"<span style='color: white;'>Current AI Icon: {selected_ai_icon}</span>",
        unsafe_allow_html=True,
    )
    if selected_ai_icon == "AI 1":
        st.session_state.selected_ai_icon = "ai1.png"
    elif selected_ai_icon == "AI 2":
        st.session_state.selected_ai_icon = "ai2.png"

    # Sidebar dropdown menu for User Icons.
    user_icons = ["Man", "Woman"]
    selected_user_icon = st.sidebar.selectbox("User Icon:", user_icons)
    st.sidebar.markdown(
        f"<span style='color: white;'>Current User Icon: {selected_user_icon}</span>",
        unsafe_allow_html=True,
    )
    if selected_user_icon == "Man":
        st.session_state.selected_user_icon = "man.png"
    elif selected_user_icon == "Woman":
        st.session_state.selected_user_icon = "woman.png"
    # Chat interface.
    chat_placeholder = st.container()
    prompt_placeholder = st.form("chat-form")
    token_placeholder = st.empty()

    with chat_placeholder:
        for chat in st.session_state.history:
            div = f"""
            <div class="chat-row
                {'' if chat.origin == 'ai' else 'row-reverse'}">
                <img class="chat-icon" src="app/static/{
                    st.session_state.selected_ai_icon if chat.origin == 'ai'
                    else st.session_state.selected_user_icon}"
                    width=32 height=32>
                <div class="chat-bubble
                    {'ai-bubble' if chat.origin == 'ai' else 'human-bubble'}">
                    &#8203;{chat.message}
                </div>
            </div>
            """
            st.markdown(div, unsafe_allow_html=True)

        for _ in range(3):
            st.markdown("")
    # User prompt.
    with prompt_placeholder:
        st.markdown("**Chat**")
        cols = st.columns((6, 1))
        cols[0].text_input(
            "Chat",
            placeholder="What is your question?",
            label_visibility="collapsed",
            key="human_prompt",
        )
        cols[1].form_submit_button(
            "Submit",
            type="primary",
            on_click=on_click_callback,
        )
    token_placeholder.caption(f"Used {st.session_state.token_count} tokens\n")

    # GitHub repository link.
    st.markdown(
        """
        <p align="center" style="font-family: monospace; color: #FAF9F6; font-size: 1rem;"><b>Check out our
        <a href="https://github.com/GeorgiosIoannouCoder/" style="color: #FAF9F6;">GitHub repository</a></b>
        </p>
        """,
        unsafe_allow_html=True,
    )
    # Enter key handler.
    components.html(
        """
        <script>
        const streamlitDoc = window.parent.document;
        const buttons = Array.from(
            streamlitDoc.querySelectorAll('.stButton > button')
        );
        const submitButton = buttons.find(
            el => el.innerText === 'Submit'
        );
        streamlitDoc.addEventListener('keydown', function(e) {
            switch (e.key) {
                case 'Enter':
                    submitButton.click();
                    break;
            }
        });
        </script>
        """,
        height=0,
        width=0,
    )


if __name__ == "__main__":
    main()
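# To run the app locally (assuming the dependencies imported above are installed and .env
# provides OPENAI_API_KEY and HUGGINGFACE_API_KEY):
#   streamlit run app.py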