DeepGit-lite / main.py
zamalali
Refactor app.py and main.py for improved readability and functionality; add environment variable loading
94ed277
raw
history blame contribute delete
15.4 kB
import os
import base64
import requests
import numpy as np
import faiss
import re
import logging
from pathlib import Path
# For local development, load environment variables from a .env file.
# In HuggingFace Spaces, secrets are automatically available as environment variables.
from dotenv import load_dotenv
load_dotenv()
from sentence_transformers import SentenceTransformer, CrossEncoder
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
# Optionally import BM25 for sparse retrieval.
try:
from rank_bm25 import BM25Okapi
except ImportError:
BM25Okapi = None
# ---------------------------
# Environment Variables & Setup
# ---------------------------
# GitHub API key. Optional: without it, GitHub API calls are sent unauthenticated,
# which works for public data but with a much lower rate limit.
GITHUB_API_KEY = os.getenv("GITHUB_API_KEY")
# GROQ API key, passed explicitly to ChatGroq.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
# HuggingFace token.
# NOTE(review): read here but not passed to any model loader in this file —
# confirm whether another module relies on it.
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
# Cross-encoder used for the final re-ranking stage; overridable via env var.
CROSS_ENCODER_MODEL = os.getenv("CROSS_ENCODER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2")
# Persistent session so connections and headers are reused across GitHub API requests.
session = requests.Session()
session.headers.update({
    "Accept": "application/vnd.github.v3+json"
})
if GITHUB_API_KEY:
    # Only send credentials when a key is configured. An unset key used to
    # produce the literal header "token None", which GitHub rejects with
    # 401 Bad credentials even for endpoints that work anonymously.
    session.headers["Authorization"] = f"token {GITHUB_API_KEY}"
# ---------------------------
# Langchain Groq Setup for Search Tag Conversion
# ---------------------------
# LLM that converts a natural-language query into GitHub search tags.
# Low temperature keeps the tag output fairly stable across runs.
# NOTE(review): the system prompt asks the model to reason inside
# <think>...</think> before printing the tags; with max_tokens=512 a long
# reasoning trace may be cut off before any tags appear — raise the limit
# if truncated outputs are observed.
llm = ChatGroq(
    model="deepseek-r1-distill-llama-70b",
    temperature=0.3,
    max_tokens=512,
    max_retries=3,
    api_key=GROQ_API_KEY  # Key read from the GROQ_API_KEY environment variable.
)
# Prompt: a fixed system message describing the tag format
# (tag1:tag2[:...], at most five tags plus an optional trailing
# target-<language> tag) followed by the user's query as the human turn.
prompt = ChatPromptTemplate.from_messages([
    ("system",
"""You are a GitHub search optimization expert.
Your job is to:
1. Read a user's query about tools, research, or tasks.
2. Detect if the query mentions a specific programming language other than Python (for example, JavaScript or JS). If so, record that language as the target language.
3. Think iteratively and generate your internal chain-of-thought enclosed in <think> ... </think> tags.
4. After your internal reasoning, output up to five GitHub-style search tags or library names that maximize repository discovery.
Use as many tags as necessary based on the query's complexity, but never more than five.
5. If you detected a non-Python target language, append an additional tag at the end in the format target-[language] (e.g., target-javascript).
If no specific language is mentioned, do not include any target tag.
Output Format:
tag1:tag2[:tag3[:tag4[:tag5[:target-language]]]]
Rules:
- Use lowercase and hyphenated keywords (e.g., image-augmentation, chain-of-thought).
- Use terms commonly found in GitHub repo names, topics, or descriptions.
- Avoid generic terms like "python", "ai", "tool", "project".
- Do NOT use full phrases or vague words like "no-code", "framework", or "approach".
- Prefer real tools, popular methods, or dataset names when mentioned.
- If your output does not strictly match the required format, correct it after your internal reasoning.
- Choose high-signal keywords to ensure the search yields the most relevant GitHub repositories.
Excellent Examples:
Input: "No code tool to augment image and annotation"
Output: image-augmentation:albumentations
Input: "Repos around chain of thought prompting mainly for finetuned models"
Output: chain-of-thought:finetuned-llm
Input: "Find repositories implementing data augmentation pipelines in JavaScript"
Output: data-augmentation:target-javascript
Output must be ONLY the search tags separated by colons. Do not include any extra text, bullet points, or explanations.
"""),
    ("human", "{query}")
])
# Runnable pipeline: invoke with {"query": ...}; the result's .content is the raw model text.
chain = prompt | llm
def valid_tags(tags: str, min_tags: int = 2, max_tags: int = 6) -> bool:
    """Return True if *tags* is exactly a colon-separated list of search tags.

    Each tag may contain only lowercase ASCII letters, digits and hyphens.
    The default bounds (2 to 6 tags) match the prompt's
    ``tag1:tag2[:tag3[:tag4[:tag5[:target-language]]]]`` format: up to five
    tags plus an optional ``target-<language>`` tag.

    Args:
        tags: Candidate string, e.g. ``"image-augmentation:albumentations"``.
        min_tags: Minimum number of tags required (must be >= 1).
        max_tags: Maximum number of tags allowed (must be >= min_tags).

    Returns:
        True when the entire string matches the format, False otherwise.

    Raises:
        ValueError: If the tag-count bounds are inconsistent.
    """
    if min_tags < 1 or max_tags < min_tags:
        raise ValueError(f"invalid tag bounds: min_tags={min_tags}, max_tags={max_tags}")
    pattern = rf"[a-z0-9-]+(?::[a-z0-9-]+){{{min_tags - 1},{max_tags - 1}}}"
    # fullmatch instead of match + '$': '$' also matches before a trailing
    # newline, so "a:b\n" used to be accepted as valid.
    return re.fullmatch(pattern, tags) is not None
def parse_search_tags(response: str) -> str:
    """Extract the colon-separated search tags from a raw LLM response.

    Reasoning is discarded before matching:
    - complete ``<think>...</think>`` blocks are removed;
    - an unterminated ``<think>`` block (generation stopped before the closing
      tag) is removed through the end of the text;
    - if a stray ``</think>`` remains (no opening tag), only the text after its
      last occurrence is kept.

    This ensures tags that appear only inside the reasoning are never returned.
    The remainder is lower-cased, and the first run of 2 to 6 colon-separated
    ``[a-z0-9-]`` tags is returned.

    Args:
        response: Raw model output.

    Returns:
        The extracted tag string. If no tag sequence is found, returns the
        reasoning-stripped text (unchanged in case), stripped of surrounding
        whitespace.
    """
    # '(?:</think>|$)' also swallows a <think> block whose closing tag never arrived.
    cleaned = re.sub(r'<think>.*?(?:</think>|$)', '', response, flags=re.DOTALL)
    # A closing tag without an opening one means everything before it was reasoning.
    cleaned = cleaned.rsplit('</think>', 1)[-1]
    # Lower-case first: otherwise "Image-augmentation:albumentations" matches as
    # the truncated "mage-augmentation:albumentations", and fully capitalised
    # tags do not match at all.
    match = re.search(r'([a-z0-9-]+(?::[a-z0-9-]+){1,5})', cleaned.lower())
    if match:
        return match.group(1).strip()
    return cleaned.strip()
def iterative_convert_to_search_tags(query: str, max_iterations: int = 2) -> str:
    """Ask the LLM chain for search tags, re-prompting while the format is invalid.

    Each attempt sends the query (plus a format reminder on retries) to
    ``chain``, extracts tags with ``parse_search_tags`` and checks them with
    ``valid_tags``. Progress is printed to stdout.

    Args:
        query: The user's natural-language query.
        max_iterations: Maximum number of LLM calls.

    Returns:
        The first valid tag string. Otherwise, the last extracted candidate,
        which may be invalid, or "" if no attempt was made.
    """
    print(f"\n🧠 [iterative_convert_to_search_tags] Input Query: {query}")
    retry_suffix = "\nPlease refine your answer so that the output strictly matches the format: tag1:tag2[:tag3[:tag4[:tag5[:target-language]]]]."
    current_prompt = query
    candidate = ""
    attempt = 0
    while attempt < max_iterations:
        attempt += 1
        print(f"\n🔄 Iteration {attempt}")
        raw_text = chain.invoke({"query": current_prompt}).content.strip()
        candidate = parse_search_tags(raw_text)
        print(f"Output Tags: {candidate}")
        if valid_tags(candidate):
            print("✅ Valid tags format detected.")
            return candidate
        print("⚠️ Invalid tags format. Requesting refinement...")
        # Retries always restate the ORIGINAL query plus the format reminder.
        current_prompt = query + retry_suffix
    print("Final output (may be invalid):", candidate)
    return candidate
# ---------------------------
# GitHub API Helper Functions
# ---------------------------
def fetch_readme_content(repo_full_name: str, timeout: float = 10.0) -> str:
    """Return the decoded README of a GitHub repository, or "" if unavailable.

    Calls the GitHub ``/repos/{owner}/{repo}/readme`` endpoint through the
    shared ``session``. The file body arrives base64-encoded in the JSON
    ``content`` field.

    Args:
        repo_full_name: Repository in ``"owner/repo"`` form.
        timeout: Seconds to wait for the GitHub API before giving up.

    Returns:
        README text, with undecodable bytes replaced. Returns "" on non-200
        status, network failure, or a malformed payload.
    """
    readme_url = f"https://api.github.com/repos/{repo_full_name}/readme"
    try:
        response = session.get(readme_url, timeout=timeout)
    except requests.RequestException as exc:
        # Best effort: one failing repository must not abort the whole ranking run.
        logging.warning("README request failed for %s: %s", repo_full_name, exc)
        return ""
    if response.status_code != 200:
        return ""
    try:
        readme_data = response.json()
        return base64.b64decode(readme_data.get('content', '')).decode('utf-8', errors='replace')
    except (ValueError, AttributeError, TypeError):
        # Non-JSON body, unexpected JSON shape, or invalid base64: treat as "no README".
        return ""
def fetch_markdown_contents(repo_full_name: str, timeout: float = 10.0) -> str:
    """Return the concatenated text of the Markdown files in a repository's root.

    Lists the root directory via the GitHub contents API (through ``session``),
    then downloads every ``*.md`` file from its ``download_url``. Each file's
    text is prefixed with a newline.

    Args:
        repo_full_name: Repository in ``"owner/repo"`` form.
        timeout: Seconds to wait for each HTTP request.

    Returns:
        The concatenated Markdown text, or "" if the listing cannot be
        retrieved. Individual files that fail to download are skipped.
    """
    url = f"https://api.github.com/repos/{repo_full_name}/contents"
    try:
        response = session.get(url, timeout=timeout)
    except requests.RequestException as exc:
        logging.warning("Contents listing failed for %s: %s", repo_full_name, exc)
        return ""
    if response.status_code != 200:
        return ""
    try:
        items = response.json()
    except ValueError:
        return ""
    # The contents API returns an object rather than a list for non-directory
    # paths; iterating such a dict would yield keys (strings) and crash below.
    if not isinstance(items, list):
        return ""
    parts = []
    for item in items:
        if not isinstance(item, dict):
            continue
        name = item.get("name") or ""
        if item.get("type") != "file" or not name.lower().endswith(".md"):
            continue
        file_url = item.get("download_url")
        if not file_url:
            continue
        try:
            # Plain requests.get (as originally written) so the API session's
            # Authorization header is not forwarded to the download host.
            file_resp = requests.get(file_url, timeout=timeout)
        except requests.RequestException as exc:
            logging.warning("Download failed for %s: %s", file_url, exc)
            continue
        if file_resp.status_code == 200:
            parts.append(file_resp.text)
    return "".join("\n" + text for text in parts)
def fetch_all_markdown(repo_full_name: str) -> str:
    """Return a repository's README followed by the other Markdown files in its root.

    The README comes from ``fetch_readme_content`` and the root ``*.md`` files
    from ``fetch_markdown_contents``. When the README is itself a root-level
    ``.md`` file, it also appears in that listing. In that case one exact copy
    of the README text is removed from the second part, so the README is not
    counted twice by the downstream BM25 and cross-encoder scoring.

    Args:
        repo_full_name: Repository in ``"owner/repo"`` form.

    Returns:
        ``readme + "\\n" + other_markdown`` (either part may be empty).
    """
    readme = fetch_readme_content(repo_full_name)
    other_md = fetch_markdown_contents(repo_full_name)
    if readme and readme in other_md:
        other_md = other_md.replace(readme, "", 1)
    return readme + "\n" + other_md
def fetch_github_repositories(query: str, max_results: int = 10, timeout: float = 15.0) -> list:
    """Search GitHub repositories and attach their textual content.

    Args:
        query: GitHub search query, which may include qualifiers such as
            ``in:name,description,readme`` or ``language:python``.
        max_results: Number of results requested (sent as ``per_page``).
        timeout: Seconds to wait for the search request.

    Returns:
        A list of dicts with keys ``title``, ``link`` and ``combined_text``.
        ``combined_text`` is the repository description followed by its
        README/Markdown text. Returns an empty list on any request or API error
        (the error is printed).
    """
    url = "https://api.github.com/search/repositories"
    params = {
        "q": query,
        "per_page": max_results
    }
    try:
        response = session.get(url, params=params, timeout=timeout)
    except requests.RequestException as exc:
        print(f"Error: GitHub search request failed: {exc}")
        return []
    if response.status_code != 200:
        # Error bodies normally carry a JSON "message", but the body may not be
        # JSON at all; fall back to a short slice of the raw text.
        try:
            message = response.json().get('message')
        except (ValueError, AttributeError):
            message = response.text[:200]
        print(f"Error {response.status_code}: {message}")
        return []
    try:
        items = response.json().get('items', [])
    except (ValueError, AttributeError):
        print("Error: unexpected response body from GitHub search")
        return []
    repo_list = []
    for repo in items:
        description = repo.get('description') or ""
        full_name = repo.get('full_name')
        # Without a full name there is no repository to fetch Markdown from.
        combined_markdown = fetch_all_markdown(full_name) if full_name else ""
        repo_list.append({
            "title": repo.get('name', 'No title available'),
            "link": repo.get('html_url'),
            "combined_text": (description + "\n" + combined_markdown).strip(),
        })
    return repo_list
# ---------------------------
# Dense Retrieval Model Setup
# ---------------------------
def _load_dense_model(model_id: str) -> SentenceTransformer:
    """Load *model_id* on the GPU; if that raises, report it and load on the CPU."""
    try:
        return SentenceTransformer(model_id, device='cuda')
    except Exception as e:
        print("Error initializing GPU for SentenceTransformer; falling back to CPU:", e)
        return SentenceTransformer(model_id, device='cpu')


# Shared dense-embedding model for document/query encoding.
# NOTE(review): HUGGINGFACE_TOKEN is not passed to the loader, so only public models load.
model = _load_dense_model('all-mpnet-base-v2')
def robust_min_max_norm(scores: np.ndarray) -> np.ndarray:
    """Min-max scale *scores* into the range [0, 1].

    Args:
        scores: Array-like of numeric scores, of any shape. Lists and integer
            arrays are accepted.

    Returns:
        A floating-point array of the same shape. Floating inputs keep their
        dtype; other inputs become float64. If the spread of the values is
        below 1e-10, every entry is 1.0, so a constant signal neither favours
        nor penalises any item. An empty input returns an empty array instead
        of raising.
    """
    values = np.asarray(scores)
    if not np.issubdtype(values.dtype, np.floating):
        values = values.astype(float)
    if values.size == 0:
        # min()/max() of an empty array raise ValueError.
        return values
    min_val = values.min()
    spread = values.max() - min_val
    if spread < 1e-10:
        return np.ones_like(values)
    return (values - min_val) / spread
# ---------------------------
# Cross-Encoder Re-Ranking Function
# ---------------------------
# Loaded cross-encoders keyed by model name, so repeated ranking runs do not
# reload the model weights on every call.
_CROSS_ENCODER_CACHE = {}


def _get_cross_encoder(model_name: str):
    """Return a cached CrossEncoder for *model_name*, loading it (GPU, else CPU) on first use."""
    cross_encoder = _CROSS_ENCODER_CACHE.get(model_name)
    if cross_encoder is None:
        try:
            cross_encoder = CrossEncoder(model_name, device='cuda')
        except Exception as e:
            print("Error initializing CrossEncoder on GPU; falling back to CPU:", e)
            cross_encoder = CrossEncoder(model_name, device='cpu')
        _CROSS_ENCODER_CACHE[model_name] = cross_encoder
    return cross_encoder


def cross_encoder_rerank_candidates(candidates: list, query: str, model_name: str, top_n: int = 10) -> list:
    """Score each candidate against *query* with a cross-encoder, in place.

    Scoring rules:
    - Each candidate's ``combined_text`` is truncated to 5000 characters.
    - Texts shorter than 200 characters are scored as a single pair.
    - Longer texts are split into 2000-character chunks and scored as
      0.5 * best chunk score + 0.5 * mean chunk score.
    - A candidate whose scoring raises is logged and gets 0.0.
    - If any score is negative, all scores are shifted up by the same amount
      so the minimum becomes 0; relative order is unchanged.

    Args:
        candidates: Dicts optionally holding ``combined_text`` and ``link``.
            Each receives a float ``cross_encoder_score``.
        query: The user query.
        model_name: Cross-encoder model identifier.
        top_n: Unused; kept for backward compatibility with existing callers.

    Returns:
        The same *candidates* list, mutated in place with its order unchanged.
    """
    cross_encoder = _get_cross_encoder(model_name)
    CHUNK_SIZE = 2000
    MAX_DOC_LENGTH = 5000
    MIN_DOC_LENGTH = 200
    for candidate in candidates:
        # `or ""` also covers an explicit None, which previously crashed len().
        doc = (candidate.get("combined_text") or "")[:MAX_DOC_LENGTH]
        try:
            if len(doc) < MIN_DOC_LENGTH:
                score = cross_encoder.predict([[query, doc]])
                candidate["cross_encoder_score"] = float(score[0])
            else:
                chunks = [doc[i:i + CHUNK_SIZE] for i in range(0, len(doc), CHUNK_SIZE)]
                scores = np.asarray(cross_encoder.predict([[query, chunk] for chunk in chunks]), dtype=float)
                # Test the size explicitly. The old `if scores` on a numpy
                # array with more than one element raised "truth value ... is
                # ambiguous", which sent every multi-chunk document to the
                # 0.0 fallback below.
                if scores.size:
                    candidate["cross_encoder_score"] = float(0.5 * scores.max() + 0.5 * scores.mean())
                else:
                    candidate["cross_encoder_score"] = 0.0
        except Exception as e:
            logging.error("Error scoring candidate %s: %s", candidate.get('link', 'unknown'), e)
            candidate["cross_encoder_score"] = 0.0
    if candidates:
        min_score = min(candidate["cross_encoder_score"] for candidate in candidates)
        if min_score < 0:
            for candidate in candidates:
                candidate["cross_encoder_score"] -= min_score
    return candidates
# ---------------------------
# Main Ranking Function with Hybrid Retrieval and Combined Scoring
# ---------------------------
def _split_target_language(tag_list: list) -> tuple:
    """Separate ``target-<lang>`` tags from *tag_list*; return (search_tags, language_qualifier).

    The first ``target-`` tag selects the GitHub ``language:`` qualifier.
    Without one, the qualifier defaults to ``language:python``.
    """
    target_tags = [tag for tag in tag_list if tag.startswith("target-")]
    if not target_tags:
        return tag_list, "language:python"
    search_tags = [tag for tag in tag_list if not tag.startswith("target-")]
    return search_tags, f"language:{target_tags[0].replace('target-', '')}"


def _collect_repositories(tag_list: list, lang_query: str) -> list:
    """Run one GitHub search per tag plus an OR-combined search; return unique repositories.

    Repositories are de-duplicated by ``link``, keeping the first occurrence.
    """
    advanced_qualifier = "in:name,description,readme"
    all_repositories = []
    for tag in tag_list:
        github_query = f"{tag} {advanced_qualifier} {lang_query}"
        print("GitHub Query:", github_query)
        all_repositories.extend(fetch_github_repositories(github_query, max_results=15))
    # With a single tag, the OR query would repeat the per-tag search above.
    if len(tag_list) > 1:
        combined_query = f"({' OR '.join(tag_list)}) {advanced_qualifier} {lang_query}"
        print("Combined GitHub Query:", combined_query)
        all_repositories.extend(fetch_github_repositories(combined_query, max_results=15))
    unique_repositories = {}
    for repo in all_repositories:
        # A repository returned by several searches carries the same fetched
        # text each time. Concatenating the copies only inflated its term
        # counts, so keep a single copy.
        unique_repositories.setdefault(repo["link"], repo)
    return list(unique_repositories.values())


def _hybrid_scores(query: str, docs: list) -> np.ndarray:
    """Return 0.8 * normalized dense cosine + 0.2 * normalized BM25 per document.

    Entry i of the result belongs to docs[i].
    """
    doc_embeddings = model.encode(docs, convert_to_numpy=True, show_progress_bar=True, batch_size=16)
    if doc_embeddings.ndim == 1:
        doc_embeddings = doc_embeddings.reshape(1, -1)
    norms = np.linalg.norm(doc_embeddings, axis=1, keepdims=True)
    # faiss expects contiguous float32 arrays.
    norm_doc_embeddings = np.ascontiguousarray(doc_embeddings / (norms + 1e-10), dtype=np.float32)
    query_embedding = model.encode(query, convert_to_numpy=True)
    if query_embedding.ndim == 1:
        query_embedding = query_embedding.reshape(1, -1)
    norm_query_embedding = np.ascontiguousarray(
        query_embedding / (np.linalg.norm(query_embedding) + 1e-10), dtype=np.float32
    )
    index = faiss.IndexFlatIP(norm_doc_embeddings.shape[1])
    index.add(norm_doc_embeddings)
    k = norm_doc_embeddings.shape[0]
    D, I = index.search(norm_query_embedding, k)
    # faiss returns hits sorted by similarity, with I[0] holding each hit's
    # document index. Scatter the scores back into document order: reading D
    # directly gave the best score to docs[0], the second best to docs[1],
    # and so on. A 1-D result also avoids the 0-d array that D.squeeze()
    # produced for a single document.
    dense_scores = np.zeros(k, dtype=np.float32)
    hits = I[0] >= 0
    dense_scores[I[0][hits]] = D[0][hits]
    norm_dense_scores = robust_min_max_norm(dense_scores)
    if BM25Okapi is not None:
        tokenized_docs = [re.findall(r'\w+', doc.lower()) for doc in docs]
        bm25 = BM25Okapi(tokenized_docs)
        query_tokens = re.findall(r'\w+', query.lower())
        norm_bm25_scores = robust_min_max_norm(np.array(bm25.get_scores(query_tokens)))
    else:
        norm_bm25_scores = np.zeros_like(norm_dense_scores)
    alpha = 0.8  # Dense retrieval weighted higher than BM25.
    return alpha * norm_dense_scores + (1 - alpha) * norm_bm25_scores


def _format_ranked_output(final_ranked: list) -> str:
    """Render the ranked repositories as the plain-text report (scores shown as percentages)."""
    parts = ["\n=== Ranked Repositories ===\n"]
    for rank, repo in enumerate(final_ranked, 1):
        snippet = repo['combined_text'][:300].replace('\n', ' ')
        parts.append(
            f"Final Rank: {rank}\n"
            f"Title: {repo['title']}\n"
            f"Link: {repo['link']}\n"
            f"Combined Score: {repo.get('combined_score', 0) * 100:.2f}%\n"
            f"Cross-Encoder Score: {repo.get('norm_cross_encoder_score', 0) * 100:.2f}%\n"
            f"Final Score: {repo.get('final_score', 0) * 100:.2f}%\n"
            f"Snippet: {snippet}...\n"
        )
        parts.append('-' * 80 + "\n")
    parts.append("\n=== End of Results ===")
    return "".join(parts)


def run_repository_ranking(query: str) -> str:
    """Rank GitHub repositories for a natural-language *query*.

    Pipeline:
    1. The LLM turns the query into search tags.
    2. GitHub is searched once per tag, plus one OR query.
    3. Results are scored with hybrid dense + BM25 scoring.
    4. The top 100 are re-scored with the cross-encoder.
    5. The two signals are blended and the top 10 are reported.

    Args:
        query: The user's natural-language request.

    Returns:
        A plain-text ranked report, or "No repositories found for your
        query." when the searches return nothing.
    """
    # Step 1: Generate search tags from the query.
    search_tags = iterative_convert_to_search_tags(query)
    tag_list = [tag.strip() for tag in search_tags.split(":") if tag.strip()]
    # Step 2: Extract the target-language qualifier.
    tag_list, lang_query = _split_target_language(tag_list)
    if not tag_list and query.strip():
        # No usable tags (e.g. only a target tag, or nothing at all): search
        # with the raw query instead of sending an empty "()" expression.
        tag_list = [query.strip()]
    # Step 3: Search GitHub and de-duplicate.
    repositories = _collect_repositories(tag_list, lang_query)
    if not repositories:
        return "No repositories found for your query."
    # Steps 4-7: Hybrid dense + BM25 scoring.
    docs = [repo.get("combined_text", "") for repo in repositories]
    for repo, score in zip(repositories, _hybrid_scores(query, docs)):
        repo["combined_score"] = float(score)
    # Step 8: Initial ranking by combined score.
    ranked_repositories = sorted(repositories, key=lambda x: x.get("combined_score", 0), reverse=True)
    # Step 9: Cross-encoder scores for the top candidates.
    top_candidates = ranked_repositories[:100]
    cross_encoder_rerank_candidates(top_candidates, query, model_name=CROSS_ENCODER_MODEL, top_n=len(top_candidates))
    # cross_encoder_score is not confined to [0, 1]: it is a raw model score,
    # only shifted to be >= 0. Normalize it like the other signals so the
    # 0.7 / 0.3 weights are meaningful and the percentages are real percentages.
    norm_cross_scores = robust_min_max_norm(
        np.array([c.get("cross_encoder_score", 0.0) for c in top_candidates], dtype=float)
    )
    w1, w2 = 0.7, 0.3
    for candidate, norm_ce in zip(top_candidates, norm_cross_scores):
        candidate["norm_cross_encoder_score"] = float(norm_ce)
        candidate["final_score"] = w1 * candidate.get("combined_score", 0) + w2 * float(norm_ce)
    final_ranked = sorted(top_candidates, key=lambda x: x.get("final_score", 0), reverse=True)[:10]
    # Step 10: Format the final report.
    return _format_ranked_output(final_ranked)
# ---------------------------
# Main Entry Point for Testing
# ---------------------------
if __name__ == "__main__":
    # Manual end-to-end run with a sample query; needs network access to GitHub and the Groq API.
    print(run_repository_ranking("Chain of thought prompting for reasoning models"))