Muhusystem committed
Commit 9e5f3db · 1 Parent(s): a5ee181

Split text prediction and attribution analysis into separate buttons
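
The change keeps a single image-and-question input pair but gives each task its own trigger: the "Answer" button runs text prediction, and the "Generate Attribution" button runs the attribution analysis. A minimal sketch of the wiring pattern the commit ends up with (component and handler names taken from the diff below; the handler bodies are placeholders, not the app's model code):

import gradio as gr

# Placeholder handlers; the real app runs the VQA model and an attribution method.
def predict_text(image, text):
    return f"(answer for: {text})"

def generate_attribution(image, text):
    return image, image  # stand-ins for the heatmap and contour overlays

with gr.Blocks() as demo:
    input_image = gr.Image(label="Input Image", interactive=True)
    question_input = gr.Textbox(label="Question", lines=3, max_lines=3)
    predict_button = gr.Button("Answer")
    prediction_output = gr.Textbox(label="Answer", lines=2, interactive=False)
    attribution_button = gr.Button("Generate Attribution")
    attribution_image_1 = gr.Image(label="Attribution Image", interactive=False)
    attribution_image_2 = gr.Image(label="Attribution with Contours", interactive=False)

    # Each button calls only its own function, so answering and attribution run independently.
    predict_button.click(predict_text, inputs=[input_image, question_input], outputs=prediction_output)
    attribution_button.click(generate_attribution, inputs=[input_image, question_input],
                             outputs=[attribution_image_1, attribution_image_2])

demo.launch()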

Files changed (1)
app.py +5 -6
app.py CHANGED
@@ -61,7 +61,6 @@ feature_extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-2
 
 # Inference function
 def predict_text(image, text):
-    image = Image.fromarray(image)
     image_features = feature_extractor(images=image, return_tensors="pt")
 
     inputs = tokenizer.encode_plus(
@@ -84,7 +83,6 @@ def predict_text(image, text):
 
 # Attribution analysis function
 def generate_attribution(image, text):
-    image = Image.fromarray(image)
     image_features = feature_extractor(images=image, return_tensors="pt")
 
     inputs = tokenizer.encode_plus(
@@ -130,20 +128,21 @@ def generate_attribution(image, text):
 with gr.Blocks() as demo:
     with gr.Row():
         with gr.Column():
-            input_image = gr.Image(label="Input Image", interactive=True)
-            question_input = gr.Textbox(label="Question", lines=2, interactive=True)
+            input_image = gr.Image(label="Input Image", interactive=True, height=400)
+            question_input = gr.Textbox(label="Question", lines=3, max_lines=3)
             clear_button = gr.Button("Clear")
         with gr.Column():
             predict_button = gr.Button("Answer")
             prediction_output = gr.Textbox(label="Answer", lines=2, interactive=False)
             attribution_button = gr.Button("Generate Attribution")
     with gr.Row():
-        attribution_image_1 = gr.Image(label="Attribution Image", interactive=False)
-        attribution_image_2 = gr.Image(label="Attribution with Contours", interactive=False)
+        attribution_image_1 = gr.Image(label="Attribution Image", interactive=False, height=400)
+        attribution_image_2 = gr.Image(label="Attribution with Contours", interactive=False, height=400)
 
     # Bind button events
     predict_button.click(predict_text, inputs=[input_image, question_input], outputs=prediction_output)
     attribution_button.click(generate_attribution, inputs=[input_image, question_input], outputs=[attribution_image_1, attribution_image_2])
     clear_button.click(lambda: (None, "", ""), outputs=[input_image, question_input, prediction_output])
 
+# Launch the Gradio interface
 demo.launch()
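
A side note on the two removed lines: both handlers previously converted the incoming array with Image.fromarray before calling the feature extractor, and this commit drops that step. A plausible reading (an assumption, not stated in the commit message) is that ViTImageProcessor accepts PIL images and NumPy arrays alike, so the conversion was redundant for the array that gr.Image passes in by default. A small sketch of that assumption (the full checkpoint name is truncated in the hunk header above; "google/vit-base-patch16-224" is assumed here):

import numpy as np
from transformers import ViTImageProcessor

# Assumption: the processor normalizes NumPy arrays (H, W, C, uint8) just like PIL images,
# which would make the explicit Image.fromarray(...) call unnecessary.
feature_extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

dummy_image = np.zeros((224, 224, 3), dtype=np.uint8)  # stand-in for the gr.Image input
image_features = feature_extractor(images=dummy_image, return_tensors="pt")
print(image_features["pixel_values"].shape)  # expected: torch.Size([1, 3, 224, 224])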