mjavaid committed
Commit d58a265 · 1 Parent(s): 67bff2d

first commit

Files changed (2)
  1. app.py +151 -9
  2. requirements.txt +8 -0
app.py CHANGED
@@ -1,14 +1,156 @@
  import gradio as gr
- import spaces
  import torch
+ from transformers import AutoProcessor, AutoModelForImageTextToText
+ from PIL import Image
+ import io
+ import requests
+ import spaces
+
+ # Initialize model and processor globally for caching
+ model_id = "CohereForAI/aya-vision-8b"
+ processor = None
+ model = None
+
+ def load_model():
+     global processor, model
+     if processor is None or model is None:
+         try:
+             processor = AutoProcessor.from_pretrained(model_id)
+             model = AutoModelForImageTextToText.from_pretrained(
+                 model_id, device_map="auto", torch_dtype=torch.float16
+             )
+             return "Model loaded successfully!"
+         except Exception as e:
+             return f"Error loading model: {e}\nMake sure to install the correct version of transformers with: pip install 'git+https://github.com/huggingface/transformers.git@v4.49.0-AyaVision'"
+     return "Model already loaded!"
+ @spaces.GPU
+ def process_image_and_prompt(image, image_url, prompt, temperature=0.3, max_tokens=300):
+     global processor, model
+
+     # Ensure model is loaded
+     if processor is None or model is None:
+         return "Please load the model first using the 'Load Model' button."
+
+     # Process image input (either uploaded or from URL)
+     if image is not None:
+         img = Image.fromarray(image)
+     elif image_url and image_url.strip():
+         try:
+             response = requests.get(image_url)
+             img = Image.open(io.BytesIO(response.content))
+         except Exception as e:
+             return f"Error loading image from URL: {e}"
+     else:
+         return "Please provide either an image or an image URL."
+
+     # Format message with the aya-vision chat template
+     messages = [
+         {"role": "user",
+          "content": [
+              {"type": "image", "source": img},
+              {"type": "text", "text": prompt},
+          ]},
+     ]
+
+     # Process input
+     try:
+         inputs = processor.apply_chat_template(
+             messages,
+             padding=True,
+             add_generation_prompt=True,
+             tokenize=True,
+             return_dict=True,
+             return_tensors="pt"
+         ).to(model.device)
+
+         # Generate response
+         gen_tokens = model.generate(
+             **inputs,
+             max_new_tokens=int(max_tokens),
+             do_sample=True,
+             temperature=float(temperature),
+         )
+
+         response = processor.tokenizer.decode(gen_tokens[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
+         return response
+     except Exception as e:
+         return f"Error generating response: {e}"
+
+ # Define example inputs
+ examples = [
+     [None, "https://media.istockphoto.com/id/458012057/photo/istanbul-turkey.jpg?s=612x612&w=0&k=20&c=qogAOVvkpfUyqLUMr_XJQyq-HkACXyYUSZbKhBlPrxo=", "What landmark is shown in this image?", 0.3, 300],
+     [None, "https://pbs.twimg.com/media/Fx7YvfQWYAIp6rZ?format=jpg&name=medium", "What does the text in this image say?", 0.3, 300],
+     [None, "https://upload.wikimedia.org/wikipedia/commons/d/da/The_Parthenon_in_Athens.jpg", "Describe esta imagen en español", 0.3, 300]
+ ]

- zero = torch.Tensor([0]).cuda()
- print(zero.device) # <-- 'cpu' 🤔
+ # Create Gradio application
+ with gr.Blocks(title="Aya Vision 8B Demo") as demo:
+     gr.Markdown("# Aya Vision 8B Model Demo")
+     gr.Markdown("""
+     This app demonstrates the C4AI Aya Vision 8B model, an 8-billion-parameter vision-language model with capabilities including:
+     - OCR (reading text from images)
+     - Image captioning
+     - Visual reasoning
+     - Question answering
+     - Support for 23 languages
+
+     Upload an image or provide a URL, and enter a prompt to get started!
+     """)
+
+     with gr.Row():
+         with gr.Column():
+             load_button = gr.Button("Load Model", variant="primary")
+             status = gr.Textbox(label="Model Status", placeholder="Model not loaded yet. Click 'Load Model' to start.")
+
+             gr.Markdown("### Upload an image or provide an image URL:")
+             with gr.Tab("Upload Image"):
+                 image_input = gr.Image(label="Upload Image", type="numpy")
+                 image_url_input = gr.Textbox(label="Image URL", placeholder="Leave blank if uploading an image", visible=False)
+
+             with gr.Tab("Image URL"):
+                 image_url_visible = gr.Textbox(label="Image URL", placeholder="Enter a URL to an image")
+                 image_input_url = gr.Image(label="Upload Image", type="numpy", visible=False)
+
+             prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt to the model", lines=3)
+
+             with gr.Accordion("Generation Settings", open=False):
+                 temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.3, label="Temperature")
+                 max_tokens = gr.Slider(minimum=50, maximum=1000, step=50, value=300, label="Max Tokens")
+
+             generate_button = gr.Button("Generate Response", variant="primary")
+
+         with gr.Column():
+             output = gr.Textbox(label="Model Response", lines=10)
+
+     # Add examples section
+     gr.Markdown("### Examples")
+     gr.Examples(
+         examples=examples,
+         inputs=[image_input, image_url_visible, prompt, temperature, max_tokens],
+         outputs=output,
+         fn=process_image_and_prompt
+     )

- @spaces.GPU
- def greet(n):
-     print(zero.device) # <-- 'cuda:0' 🤗
-     return f"Hello {zero + n} Tensor"
+     # Set up tab switching logic - hide appropriate inputs depending on tab
+     def update_image_tab():
+         return {image_url_input: gr.update(visible=False), image_input: gr.update(visible=True)}
+
+     def update_url_tab():
+         return {image_url_visible: gr.update(visible=True), image_input_url: gr.update(visible=False)}
+
+     # Define button click behavior
+     load_button.click(load_model, inputs=None, outputs=status)
+
+     # Handle generation from either image or URL
+     def generate_response(image, image_url_visible, prompt, temperature, max_tokens):
+         return process_image_and_prompt(image, image_url_visible, prompt, temperature, max_tokens)
+
+     generate_button.click(
+         generate_response,
+         inputs=[image_input, image_url_visible, prompt, temperature, max_tokens],
+         outputs=output
+     )

- demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
- demo.launch()
+ # Launch the Gradio app
+ if __name__ == "__main__":
+     demo.launch()
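
For reviewers who want to exercise the new inference path outside the Space, here is a minimal standalone sketch assembled from the calls in app.py above: same model id, processor, chat-template message format, and generate arguments, with the image URL and prompt taken from the app's own examples list. It is not part of this commit; it assumes a CUDA-capable GPU and the transformers build pinned in requirements.txt.

import torch
import requests
from io import BytesIO
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText

# Same model id and loading arguments as app.py's load_model()
model_id = "CohereForAI/aya-vision-8b"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.float16
)

# One of the app's own example images and prompts
url = "https://upload.wikimedia.org/wikipedia/commons/d/da/The_Parthenon_in_Athens.jpg"
img = Image.open(BytesIO(requests.get(url).content))
messages = [{"role": "user", "content": [
    {"type": "image", "source": img},
    {"type": "text", "text": "What landmark is shown in this image?"},
]}]

# Tokenize with the aya-vision chat template, then sample a response
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt",
).to(model.device)
gen_tokens = model.generate(**inputs, max_new_tokens=300, do_sample=True, temperature=0.3)

# Decode only the newly generated tokens, as process_image_and_prompt() does
print(processor.tokenizer.decode(gen_tokens[0][inputs.input_ids.shape[1]:], skip_special_tokens=True))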
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ gradio>=4.0.0
+ torch>=2.0.0
+ torchvision>=0.15.0
+ pillow>=9.0.0
+ requests>=2.28.0
+ numpy>=1.22.0
+ git+https://github.com/huggingface/transformers.git@v4.49.0-AyaVision
+ accelerate>=0.20.0
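
With these pins, the Space should also be reproducible locally (assuming a CUDA-capable GPU for the float16 weights): pip install -r requirements.txt, then python app.py. Note that transformers is pulled from the AyaVision git tag rather than a PyPI release; if the model fails to load, this is the install that load_model()'s error message points to.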