| | import traceback |
| | from io import StringIO |
| | from typing import Optional |
| |
|
| | import gradio as gr |
| | import pandas as pd |
| | from loguru import logger |
| |
|
| | from utils import pipeline |
| | from utils.models import list_models |
| |
|
| |
|
def read_data(filepath: str) -> pd.DataFrame:
    """Load an uploaded tabular file into a DataFrame.

    The reader is chosen from the file extension, matched case-insensitively
    so that names such as ``DATA.CSV`` or ``sheet.XLSX`` are accepted too.

    Args:
        filepath: Path to a ``.xlsx`` or ``.csv`` file.

    Returns:
        The parsed DataFrame.

    Raises:
        ValueError: If the extension is neither ``.xlsx`` nor ``.csv``
            (a subclass of ``Exception``, so existing broad handlers still
            catch it).
    """
    lowered = filepath.lower()
    if lowered.endswith('.xlsx'):
        return pd.read_excel(filepath)
    if lowered.endswith('.csv'):
        return pd.read_csv(filepath)
    raise ValueError('File type not supported')
| |
|
| |
|
def process(
    task_name: str,
    model_name: str,
    pooling: str,
    text: str,
    file=None,
) -> 'tuple[Optional[dict], Optional[pd.DataFrame], Optional[str]]':
    """Run the selected scoring task on uploaded or pasted data (Gradio callback).

    Args:
        task_name: ``'Originality'`` or ``'Flexibility'``; selects which
            ``pipeline`` step is applied.
        model_name: Model identifier forwarded unchanged to the pipeline step.
        pooling: Pooling strategy forwarded unchanged to the pipeline step
            (the UI offers ``'mean'`` and ``'cls'``).
        text: CSV text typed into the textbox; used only when no file is given.
        file: Uploaded file, either a path string or an object exposing the
            path as ``.name``. Takes precedence over ``text``.

    Returns:
        On success: ``(None, first 10 rows of the result, 'output.csv')``.
        On any failure: ``({'Info': 'Something wrong', 'Error': <traceback>},
        None, None)``. Errors are reported to the UI rather than raised.
    """
    try:
        logger.info(f'Processing {task_name} with {model_name} and {pooling}')

        if file:
            # Depending on the Gradio version, File inputs arrive either as a
            # wrapper exposing the temp path via `.name` or as a plain path str.
            df = read_data(getattr(file, 'name', file))
        elif text:
            df = pd.read_csv(StringIO(text))
        else:
            raise ValueError('No input data')

        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        # Applies to file uploads as well as pasted text.
        if len(df) < 1:
            raise ValueError('No input data')
        if len(df) > 10000:
            raise ValueError('Data exceeds 10,000 rows')

        if task_name == 'Originality':
            df = pipeline.p0_originality(df, model_name, pooling)
        elif task_name == 'Flexibility':
            df = pipeline.p1_flexibility(df, model_name, pooling)
        else:
            raise ValueError('Task not supported')

        path = 'output.csv'
        # 'utf-8-sig' prepends a UTF-8 BOM (commonly so spreadsheet tools such
        # as Excel decode non-ASCII text correctly).
        df.to_csv(path, index=False, encoding='utf-8-sig')
        return None, df.iloc[:10], path

    except Exception:
        # `Exception` rather than a bare `except:` so KeyboardInterrupt and
        # SystemExit still propagate and can stop the server.
        error = traceback.format_exc()
        logger.warning({
            'error': error,
            'task_name': task_name,
            'model_name': model_name,
            'pooling': pooling,
            'text': text,
            'file': file,
        })
        return {'Info': 'Something wrong', 'Error': error}, None, None
| |
|
| |
|
| | |
# Input widgets, listed in the order `process` receives its positional arguments
# (task_name, model_name, pooling, text, file).
task_name_dropdown = gr.components.Dropdown(
    label='Task Name',
    value='Originality',
    choices=['Originality', 'Flexibility'],
)
model_name_dropdown = gr.components.Dropdown(
    label='Model Name',
    value=list_models[0],
    choices=list_models,
)
pooling_dropdown = gr.components.Dropdown(
    label='Pooling',
    value='mean',
    choices=['mean', 'cls'],
)

# Pre-fill the textbox with the bundled example CSV. The context manager closes
# the handle, and the explicit encoding avoids depending on the platform locale.
# 'utf-8-sig' also strips a leading BOM if present, so it cannot leak into the
# first column name when the text is parsed with pd.read_csv.
with open('data/example_xlm.csv', 'r', encoding='utf-8-sig') as example_file:
    example_text = example_file.read()

text_input = gr.components.Textbox(
    value=example_text,
    lines=10,
)
file_input = gr.components.File(label='Input File', file_types=['.csv', '.xlsx'])
| |
|
| | |
# Output widgets, matching the 3-tuple returned by `process`:
# (status/error info, preview of the result, path of the downloadable CSV).
text_output = gr.components.Textbox(label='Output')
dataframe_output = gr.components.Dataframe(label='DataFrame')
file_output = gr.components.File(label='Output File', file_types=['.csv', '.xlsx'])

# Read the description through a context manager so the handle is closed, and
# decode explicitly as UTF-8 instead of the platform's locale encoding.
with open('data/description.txt', 'r', encoding='utf-8') as description_file:
    description = description_file.read()

app = gr.Interface(
    fn=process,
    inputs=[task_name_dropdown, model_name_dropdown, pooling_dropdown, text_input, file_input],
    outputs=[text_output, dataframe_output, file_output],
    description=description,
    title='TransDis-CreativityAutoAssessment',
    # Serialize requests. Note that `process` always writes to the fixed path
    # 'output.csv', so concurrent runs would overwrite each other's results.
    concurrency_limit=1,
)
app.launch(max_threads=1)
| |
|