VikramR commited on
Commit
56eb931
·
1 Parent(s): aaebde6

Autocrop off by default, added progress bar for batch prediction, made colors nice, fixed download button

Browse files
Files changed (1) hide show
  1. app.py +26 -18
app.py CHANGED
@@ -122,7 +122,7 @@ current_model_type = None
122
 
123
  results_cache: dict[str, str] = {}
124
  current_image = None
125
- autocrop = True
126
 
127
  temp_files: list[str] = []
128
  all_images: list[str] = []
@@ -157,7 +157,7 @@ def display_model():
157
  """
158
  global current_model_type
159
  model_name = model_names[current_model_type]
160
- return f"Current Model Type: {model_name}. Reupload model to change it."
161
 
162
 
163
  def clear():
@@ -229,9 +229,9 @@ def predict(img: str) -> gr.BarPlot:
229
  )
230
 
231
 
232
- def predict_all():
233
  global all_images, results_cache
234
- for img in all_images:
235
  current_image_name = Path(img).name
236
  img = Image.open(img).convert("RGB")
237
  if autocrop:
@@ -243,6 +243,7 @@ def predict_all():
243
  "Distribution": result,
244
  "Classification": {"Probability": prob, "Label": label},
245
  }
 
246
 
247
 
248
  def get_results_cache():
@@ -302,8 +303,9 @@ with gr.Blocks() as demo:
302
  gr.Textbox(
303
  "Only use this application on the following classes of nematodes: "
304
  + "Helicotylenchus, Hoplolaimus, Meloidogyne, Mesocriconema, "
305
- "Pratylenchus, Trichodorus, and Tylenchorhynchus\n\n"
306
- + "Only use images containing a single nematode.",
 
307
  text_align="center",
308
  label="DISCLAIMER",
309
  )
@@ -315,7 +317,7 @@ with gr.Blocks() as demo:
315
  model_select = gr.Dropdown(
316
  choices=["EfficientNetV2-S", "MobileNetV3-L", "ResNet101", "Swin V2-B"],
317
  value="EfficientNetV2-S",
318
- label="Select Model Architecture",
319
  )
320
  with gr.Row():
321
  with gr.Column():
@@ -324,14 +326,17 @@ with gr.Blocks() as demo:
324
  show_label=False,
325
  )
326
  files = gr.File(file_types=["image"], file_count="multiple")
327
- batch_predict = gr.Button("Predict All")
 
 
 
328
 
329
  with gr.Column():
330
  mid_col_text = gr.Textbox(
331
  "Crop Image Here (Optional), then Click Run to Predict",
332
  show_label=False,
333
  )
334
- autocrop_toggle = gr.Checkbox(value=True, label="Automatic Cropping")
335
  cropper = gr.ImageEditor(
336
  type="filepath",
337
  sources=None,
@@ -352,7 +357,7 @@ with gr.Blocks() as demo:
352
  interactive=False,
353
  mirror_webcam=False,
354
  )
355
- classify = gr.Button("Classify")
356
  plot = gr.BarPlot()
357
 
358
  with gr.Row():
@@ -361,16 +366,15 @@ with gr.Blocks() as demo:
361
  label="Predictions",
362
  )
363
  with gr.Row():
364
- json_results = gr.JSON()
365
- download = gr.DownloadButton(
366
- "Download Predictions (Click again if download does not start, Gradio bug)"
367
- )
368
 
369
  download.click(save_results, outputs=download)
370
  model_select.change(load_model, inputs=model_select).then(
371
  display_model, outputs=model_text
372
  )
373
- model_select
374
 
375
  files.upload(upload_files, inputs=files)
376
  files.select(select_image, inputs=files, outputs=cropper).then(
@@ -388,15 +392,19 @@ with gr.Blocks() as demo:
388
  outputs=preview,
389
  )
390
 
391
- batch_predict.click(predict_all).then(get_results_cache, outputs=json_results)
 
 
392
 
393
- files.clear(clear).then(get_results_cache, outputs=json_results)
 
 
394
 
395
  crop.click(show_preview, inputs=cropper, outputs=preview)
396
 
397
  classify.click(predict, inputs=preview, outputs=plot).then(
398
  get_results_cache, outputs=json_results
399
- )
400
  demo.unload(clear)
401
 
402
 
 
122
 
123
  results_cache: dict[str, str] = {}
124
  current_image = None
125
+ autocrop = False
126
 
127
  temp_files: list[str] = []
128
  all_images: list[str] = []
 
157
  """
158
  global current_model_type
159
  model_name = model_names[current_model_type]
160
+ return f"Current Model Type: {model_name}. Use dropdown on the right to change it."
161
 
162
 
163
  def clear():
 
229
  )
230
 
231
 
232
+ def predict_all(progress_bar=gr.Progress()):
233
  global all_images, results_cache
234
+ for img in progress_bar.tqdm(all_images, desc="Running images"):
235
  current_image_name = Path(img).name
236
  img = Image.open(img).convert("RGB")
237
  if autocrop:
 
243
  "Distribution": result,
244
  "Classification": {"Probability": prob, "Label": label},
245
  }
246
+ return "All images predicted successfully."
247
 
248
 
249
  def get_results_cache():
 
303
  gr.Textbox(
304
  "Only use this application on the following classes of nematodes: "
305
  + "Helicotylenchus, Hoplolaimus, Meloidogyne, Mesocriconema, "
306
+ "Pratylenchus, Trichodorus, and Tylenchorhynchus.\n\n"
307
+ + "Only use images containing a single nematode.\n\n"
308
+ + "SCROLL DOWN TO DOWNLOAD THE PREDICTIONS FOR YOUR IMAGES!",
309
  text_align="center",
310
  label="DISCLAIMER",
311
  )
 
317
  model_select = gr.Dropdown(
318
  choices=["EfficientNetV2-S", "MobileNetV3-L", "ResNet101", "Swin V2-B"],
319
  value="EfficientNetV2-S",
320
+ label="Select Model Architecture (May take a few moments, check text on the left to confirm your model has loaded)",
321
  )
322
  with gr.Row():
323
  with gr.Column():
 
326
  show_label=False,
327
  )
328
  files = gr.File(file_types=["image"], file_count="multiple")
329
+ batch_predict = gr.Button("Predict All", variant="stop")
330
+ prediction_progress = gr.Textbox(
331
+ "Prediction Progress Bar", show_label=False
332
+ )
333
 
334
  with gr.Column():
335
  mid_col_text = gr.Textbox(
336
  "Crop Image Here (Optional), then Click Run to Predict",
337
  show_label=False,
338
  )
339
+ autocrop_toggle = gr.Checkbox(value=False, label="Automatic Cropping")
340
  cropper = gr.ImageEditor(
341
  type="filepath",
342
  sources=None,
 
357
  interactive=False,
358
  mirror_webcam=False,
359
  )
360
+ classify = gr.Button("Classify", variant="stop")
361
  plot = gr.BarPlot()
362
 
363
  with gr.Row():
 
366
  label="Predictions",
367
  )
368
  with gr.Row():
369
+ with gr.Column():
370
+ json_results = gr.JSON()
371
+ with gr.Column():
372
+ download = gr.DownloadButton("Download Predictions", variant="primary")
373
 
374
  download.click(save_results, outputs=download)
375
  model_select.change(load_model, inputs=model_select).then(
376
  display_model, outputs=model_text
377
  )
 
378
 
379
  files.upload(upload_files, inputs=files)
380
  files.select(select_image, inputs=files, outputs=cropper).then(
 
392
  outputs=preview,
393
  )
394
 
395
+ batch_predict.click(predict_all, outputs=prediction_progress).then(
396
+ get_results_cache, outputs=json_results
397
+ ).then(save_results, outputs=download)
398
 
399
+ files.clear(clear).then(get_results_cache, outputs=json_results).then(
400
+ save_results, outputs=download
401
+ )
402
 
403
  crop.click(show_preview, inputs=cropper, outputs=preview)
404
 
405
  classify.click(predict, inputs=preview, outputs=plot).then(
406
  get_results_cache, outputs=json_results
407
+ ).then(save_results, outputs=download)
408
  demo.unload(clear)
409
 
410