malavika4089 commited on
Commit
1bac357
·
verified ·
1 Parent(s): 44eda15

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -0
app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import onnxruntime
3
+ import gradio as gr
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+ onnx_model_path = "sarcoloring.onnx"
8
+ sess = onnxruntime.InferenceSession(onnx_model_path)
9
+
10
+ def predict(input_image):
11
+
12
+ input_image = input_image.resize((256, 256))
13
+ input_image = np.array(input_image).transpose(2, 0, 1)
14
+ input_image = input_image.astype(np.float32) / 255.0
15
+ input_image = (input_image - 0.5) / 0.5
16
+ input_image = np.expand_dims(input_image, axis=0)
17
+
18
+ # Run the model
19
+ inputs = {sess.get_inputs()[0].name: input_image}
20
+ output = sess.run(None, inputs)
21
+
22
+
23
+ output_image = output[0].squeeze().transpose(1, 2, 0)
24
+ output_image = (output_image + 1) / 2 # [0,1]
25
+ output_image = (output_image * 255).astype(np.uint8)
26
+
27
+ return Image.fromarray(output_image)
28
+
29
+
30
+ example_images = [[os.path.join("examples", fname)] for fname in os.listdir("examples")]
31
+
32
+
33
+ iface = gr.Interface(fn=predict,
34
+ inputs=gr.Image(type="pil"),
35
+ outputs=gr.Image(type="pil"),
36
+ examples=example_images
37
+ )
38
+
39
+
40
+ iface.launch()