Anupam251272 committed
Commit 4ec1d8c · verified · 1 Parent(s): e3d8d6c

Create app.py

Files changed (1): app.py +195 -0
app.py ADDED
@@ -0,0 +1,195 @@
import os
import gradio as gr
import pdfplumber
import requests
import faiss
import json
import torch
from bs4 import BeautifulSoup
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from sentence_transformers import SentenceTransformer
import numpy as np
import tempfile
import logging
from datetime import datetime
from typing import List, Dict

# Optimize CUDA memory management
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

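# CaseStudyGenerator couples three components:
#   - facebook/opt-2.7b as the causal LM that writes the case study text,
#   - all-MiniLM-L6-v2 sentence embeddings (384 dimensions),
#   - a FAISS IndexFlatL2 that holds one embedding per generated case study.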
class CaseStudyGenerator:
    def __init__(self):
        self.model_name = "facebook/opt-2.7b"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Clear any reserved memory
        if self.device == "cuda":
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

        model_kwargs = {
            'torch_dtype': torch.float16 if self.device == "cuda" else torch.float32
        }

        try:
            self.model = AutoModelForCausalLM.from_pretrained(self.model_name, **model_kwargs)
            if self.device == "cuda":
                self.model = self.model.to(self.device)
                self.model.gradient_checkpointing_enable()
        except RuntimeError as e:
            logger.warning(f"Memory issue detected: {e}, attempting 8-bit loading.")

            try:
                from transformers import BitsAndBytesConfig
                quantization_config = BitsAndBytesConfig(load_in_8bit=True)
                self.model = AutoModelForCausalLM.from_pretrained(self.model_name, quantization_config=quantization_config)
            except ImportError:
                logger.error("Missing 'bitsandbytes'. Install it using 'pip install -U bitsandbytes'")
                logger.info("Switching to CPU to continue operations.")
                self.device = "cpu"
                self.model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype=torch.float32)

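        # Loading order so far: fp16 weights on GPU when available, 8-bit via
        # bitsandbytes if the fp16 load hits a memory error, and plain fp32 on
        # CPU if bitsandbytes is not installed.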
        self.generator = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            device=0 if self.device == "cuda" else -1,
            max_length=2048,
            num_return_sequences=1,
            temperature=0.8,
            top_p=0.95,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id
        )

        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.dimension = 384
        self.index = faiss.IndexFlatL2(self.dimension)
        self.stored_texts: List[Dict] = []

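    # clean_url keeps only http(s) links, strips query strings, and caps the
    # length at 100 characters; anything else becomes "" and is filtered out later.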
    def clean_url(self, url: str) -> str:
        if not url.startswith(('http://', 'https://')):
            return ""
        return url.split('?')[0][:100]

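    # fetch_articles scrapes a Google search results page for the topic and
    # returns up to five cleaned links. Scraping Google HTML is brittle and may
    # return nothing if the page layout changes or the request is blocked.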
    def fetch_articles(self, topic: str) -> List[str]:
        try:
            search_url = f"https://www.google.com/search?q={topic.replace(' ', '+')}+case+study+manufacturing+strategy"
            headers = {"User-Agent": "Mozilla/5.0"}
            response = requests.get(search_url, headers=headers, timeout=10)
            response.raise_for_status()

            soup = BeautifulSoup(response.text, "html.parser")
            articles = [
                self.clean_url(link.get("href", ""))
                for link in soup.find_all("a")
                if "google" not in link.get("href", "")
            ]
            # Drop entries that clean_url rejected (returned "")
            articles = [url for url in articles if url]
            return articles[:5] or ["No articles found"]
        except Exception as e:
            logger.error(f"Error fetching articles: {str(e)}")
            return ["Error fetching articles"]

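    # process_pdf writes the uploaded bytes to a temporary file, extracts text
    # page by page with pdfplumber, and deletes the temporary file afterwards.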
    def process_pdf(self, pdf_file) -> str:
        try:
            if pdf_file is None:
                return "No PDF provided"

            # gr.File(type="binary") passes raw bytes; handle both bytes and file-like objects
            data = pdf_file if isinstance(pdf_file, (bytes, bytearray)) else pdf_file.read()
            with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as temp_pdf:
                temp_pdf.write(data)
                temp_path = temp_pdf.name

            text = []
            with pdfplumber.open(temp_path) as pdf:
                text = [page.extract_text().strip() for page in pdf.pages if page.extract_text()]

            os.unlink(temp_path)
            return "\n".join(text) or "No text extracted from PDF"
        except Exception as e:
            logger.error(f"Error processing PDF: {str(e)}")
            return "Error processing PDF"

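    # generate_case_study is the end-to-end flow: fetch related links, extract
    # PDF text, build a structured prompt, sample up to 1024 new tokens from
    # the LM, then embed the result, add it to the FAISS index, and keep the
    # text in stored_texts with a timestamp.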
    def generate_case_study(self, topic: str, pdf=None) -> str:
        try:
            if self.device == "cuda":
                torch.cuda.empty_cache()

            articles = self.fetch_articles(topic)
            pdf_text = self.process_pdf(pdf) if pdf else "No PDF provided"

            prompt = f"""Write a professional case study about {topic}.
Background Information:
- Topic: {topic}
- Supporting Documents: {pdf_text[:500]}
- Related Sources: {', '.join(articles)}

Format your response as:
1. Executive Summary
2. Company Background
3. Challenge Analysis
4. Strategic Implementation
5. Results and Impact
6. Key Learnings
"""

            output = self.generator(
                prompt,
                max_new_tokens=1024,
                num_return_sequences=1,
                temperature=0.8,
                top_p=0.95,
                do_sample=True,
                repetition_penalty=1.2,
                no_repeat_ngram_size=3
            )

            case_study = output[0]['generated_text'].replace(prompt, "").strip()

            # FAISS expects a float32 matrix of shape (n, dimension)
            embedding = np.asarray(self.embedding_model.encode([case_study]), dtype="float32")
            self.index.add(embedding.reshape(1, -1))

            self.stored_texts.append({
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                "topic": topic,
                "content": case_study
            })

            return case_study
        except Exception as e:
            logger.error(f"Error generating case study: {str(e)}")
            return f"Error generating case study: {str(e)}"

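    # retrieve_past_case_studies returns the five most recently stored case
    # studies as plain text; the FAISS index is not queried here.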
    def retrieve_past_case_studies(self) -> str:
        try:
            if not self.stored_texts:
                return "No case studies generated yet."

            result = ""
            for idx, case in enumerate(self.stored_texts[-5:], start=1):
                result += (
                    f"Case Study {idx}\n"
                    f"Topic: {case['topic']}\n"
                    f"Generated on: {case['timestamp']}\n\n"
                    f"{case['content']}\n\n"
                    "=== End of Case Study ===\n\n"
                )
            return result
        except Exception as e:
            logger.error(f"Error retrieving past case studies: {str(e)}")
            return "Error retrieving past case studies"

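# The UI is two text outputs driven by two buttons: one generates a new case
# study from the topic (plus an optional PDF), the other dumps recent results.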
# Gradio interface
with gr.Blocks() as app:
    gr.Markdown("# AI Case Study Generator (Optimized for GPU-T4 & CPU)")
    with gr.Row():
        topic = gr.Textbox(label="Enter Topic")
        pdf = gr.File(label="Upload PDF", type="binary")
    with gr.Row():
        generate_btn = gr.Button("Generate Case Study")
        retrieve_btn = gr.Button("Retrieve Past Case Studies")
    output = gr.Textbox(label="Generated Case Study", lines=20)
    past_cases = gr.Textbox(label="Past Case Studies", lines=20)

    generator = CaseStudyGenerator()
    generate_btn.click(generator.generate_case_study, inputs=[topic, pdf], outputs=output)
    retrieve_btn.click(generator.retrieve_past_case_studies, outputs=past_cases)

# Launch the application
if __name__ == "__main__":
    app.launch(share=True)  # 'enable_queue' is no longer a launch() argument in current Gradio
    # To enable request queuing (Gradio 3.x and later), use:
    # app.queue().launch(share=True)
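# A typical local run (assuming the file is saved as app.py):
#   pip install gradio pdfplumber requests faiss-cpu torch transformers sentence-transformers beautifulsoup4
#   python app.py
# share=True additionally exposes a temporary public Gradio URL.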