gap-text2sql
/
gap-text2sql-main
/mrat-sql-gap
/seq2struct
/datasets
/spider_lib
/preprocess
/get_tables-ORI.py
| import os | |
| import sys | |
| import json | |
| import sqlite3 | |
| from os import listdir, makedirs | |
| from os.path import isfile, isdir, join, split, exists, splitext | |
| from nltk import word_tokenize, tokenize | |
| import traceback | |
| EXIST = {"atis", "geo", "advising", "yelp", "restaurants", "imdb", "academic"} | |
| def convert_fk_index(data): | |
| fk_holder = [] | |
| for fk in data["foreign_keys"]: | |
| tn, col, ref_tn, ref_col = fk[0][0], fk[0][1], fk[1][0], fk[1][1] | |
| ref_cid, cid = None, None | |
| try: | |
| tid = data['table_names_original'].index(tn) | |
| ref_tid = data['table_names_original'].index(ref_tn) | |
| for i, (tab_id, col_org) in enumerate(data['column_names_original']): | |
| if tab_id == ref_tid and ref_col == col_org: | |
| ref_cid = i | |
| elif tid == tab_id and col == col_org: | |
| cid = i | |
| if ref_cid and cid: | |
| fk_holder.append([cid, ref_cid]) | |
| except: | |
| traceback.print_exc() | |
| print("table_names_original: ", data['table_names_original']) | |
| print("finding tab name: ", tn, ref_tn) | |
| sys.exit() | |
| return fk_holder | |
| def dump_db_json_schema(db, f): | |
| '''read table and column info''' | |
| conn = sqlite3.connect(db) | |
| conn.execute('pragma foreign_keys=ON') | |
| cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table';") | |
| data = {'db_id': f, | |
| 'table_names_original': [], | |
| 'table_names': [], | |
| 'column_names_original': [(-1, '*')], | |
| 'column_names': [(-1, '*')], | |
| 'column_types': ['text'], | |
| 'primary_keys': [], | |
| 'foreign_keys': []} | |
| fk_holder = [] | |
| for i, item in enumerate(cursor.fetchall()): | |
| table_name = item[0] | |
| data['table_names_original'].append(table_name) | |
| data['table_names'].append(table_name.lower().replace("_", ' ')) | |
| fks = conn.execute("PRAGMA foreign_key_list('{}') ".format(table_name)).fetchall() | |
| #print("db:{} table:{} fks:{}".format(f,table_name,fks)) | |
| fk_holder.extend([[(table_name, fk[3]), (fk[2], fk[4])] for fk in fks]) | |
| cur = conn.execute("PRAGMA table_info('{}') ".format(table_name)) | |
| for j, col in enumerate(cur.fetchall()): | |
| data['column_names_original'].append((i, col[1])) | |
| data['column_names'].append((i, col[1].lower().replace("_", " "))) | |
| #varchar, '' -> text, int, numeric -> integer, | |
| col_type = col[2].lower() | |
| if 'char' in col_type or col_type == '' or 'text' in col_type or 'var' in col_type: | |
| data['column_types'].append('text') | |
| elif 'int' in col_type or 'numeric' in col_type or 'decimal' in col_type or 'number' in col_type\ | |
| or 'id' in col_type or 'real' in col_type or 'double' in col_type or 'float' in col_type: | |
| data['column_types'].append('number') | |
| elif 'date' in col_type or 'time' in col_type or 'year' in col_type: | |
| data['column_types'].append('time') | |
| elif 'boolean' in col_type: | |
| data['column_types'].append('boolean') | |
| else: | |
| data['column_types'].append('others') | |
| if col[5] == 1: | |
| data['primary_keys'].append(len(data['column_names'])-1) | |
| data["foreign_keys"] = fk_holder | |
| data['foreign_keys'] = convert_fk_index(data) | |
| return data | |
| if __name__ == '__main__': | |
| if len(sys.argv) < 2: | |
| print("Usage: python get_tables.py [dir includes many subdirs containing database.sqlite files] [output file name e.g. output.json] [existing tables.json file to be inherited]") | |
| sys.exit() | |
| input_dir = sys.argv[1] | |
| output_file = sys.argv[2] | |
| ex_tab_file = sys.argv[3] | |
| all_fs = [df for df in listdir(input_dir) if exists(join(input_dir, df, df+'.sqlite'))] | |
| with open(ex_tab_file) as f: | |
| ex_tabs = json.load(f) | |
| #for tab in ex_tabs: | |
| # tab["foreign_keys"] = convert_fk_index(tab) | |
| ex_tabs = {tab["db_id"]: tab for tab in ex_tabs if tab["db_id"] in all_fs} | |
| print("precessed file num: ", len(ex_tabs)) | |
| not_fs = [df for df in listdir(input_dir) if not exists(join(input_dir, df, df+'.sqlite'))] | |
| for d in not_fs: | |
| print("no sqlite file found in: ", d) | |
| db_files = [(df+'.sqlite', df) for df in listdir(input_dir) if exists(join(input_dir, df, df+'.sqlite'))] | |
| tables = [] | |
| for f, df in db_files: | |
| #if df in ex_tabs.keys(): | |
| #print 'reading old db: ', df | |
| # tables.append(ex_tabs[df]) | |
| db = join(input_dir, df, f) | |
| print('\nreading new db: ', df) | |
| table = dump_db_json_schema(db, df) | |
| prev_tab_num = len(ex_tabs[df]["table_names"]) | |
| prev_col_num = len(ex_tabs[df]["column_names"]) | |
| cur_tab_num = len(table["table_names"]) | |
| cur_col_num = len(table["column_names"]) | |
| if df in ex_tabs.keys() and prev_tab_num == cur_tab_num and prev_col_num == cur_col_num and prev_tab_num != 0 and len(ex_tabs[df]["column_names"]) > 1: | |
| table["table_names"] = ex_tabs[df]["table_names"] | |
| table["column_names"] = ex_tabs[df]["column_names"] | |
| else: | |
| print("\n----------------------------------problem db: ", df) | |
| tables.append(table) | |
| print("final db num: ", len(tables)) | |
| with open(output_file, 'wt') as out: | |
| json.dump(tables, out, sort_keys=True, indent=2, separators=(',', ': ')) |