Spaces:
Configuration error
Configuration error
Commit
·
5b81931
1
Parent(s):
928c735
Upload 14 files
Browse files- .env +1 -0
- .gitattributes +7 -34
- README.md +2 -12
- __pycache__/process.cpython-37.pyc +0 -0
- amharic.csv +3 -0
- app.py +51 -0
- hausa.csv +3 -0
- igbo.csv +0 -0
- news.ann +3 -0
- process.py +133 -0
- requirements.txt +16 -0
- swahili.csv +3 -0
- utils.py +32 -0
- yoruba.csv +3 -0
.env
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
COHERE_API_KEY = 7rMjNpj7LLTNlAcoR1Sc6cH23aURrBQoMPi9vzam
|
.gitattributes
CHANGED
@@ -1,34 +1,7 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
1 |
+
# Auto detect text files and perform LF normalization
|
2 |
+
* text=auto
|
3 |
+
amharic.csv filter=lfs diff=lfs merge=lfs -text
|
4 |
+
hausa.csv filter=lfs diff=lfs merge=lfs -text
|
5 |
+
news.ann filter=lfs diff=lfs merge=lfs -text
|
6 |
+
swahili.csv filter=lfs diff=lfs merge=lfs -text
|
7 |
+
yoruba.csv filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
@@ -1,12 +1,2 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
emoji: 🔥
|
4 |
-
colorFrom: gray
|
5 |
-
colorTo: green
|
6 |
-
sdk: streamlit
|
7 |
-
sdk_version: 1.19.0
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
-
---
|
11 |
-
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
+
# cluster_news
|
2 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__pycache__/process.cpython-37.pyc
ADDED
Binary file (3.25 kB). View file
|
|
amharic.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:59b8670c19f95f0cff667b8d5f69033e93bcdd2dec5e1cc069f82d93699da894
|
3 |
+
size 36144176
|
app.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st

from process import *

# Load (a sample of) the multilingual news corpus once per script run.
df = import_ds()

st.title('AFri News Multilingual Embedding')

form = st.form(key="user_settings")

# Containers reserve layout slots so results render below the form.
textcontainer = st.container()

plotcontainer = st.container()

with form:
    query = st.text_input('Please input your news text here:')
    num_nearest = int(st.slider('Please input the number of news to find: ', value=15, min_value=1, max_value=200))
    generate_button = form.form_submit_button("Cluster News")

if generate_button:
    key = get_key()
    co = cohere.Client(key)

    # Embed the corpus and (re)build the Annoy index on every click.
    # NOTE(review): this re-embeds all rows per request — consider caching.
    embeddings = getEmbeddings(co, df)
    indexfile = 'news.ann'
    semantic_search(embeddings, indexfile)

    # Embed the user's query and fetch its nearest neighbours.
    query_embed = get_query_embed(co, query)
    nearest_ids = getClosestNeighbours(indexfile, query_embed, num_nearest)

    # Project the full corpus to 2-D for the scatter plot.
    # (The previous version also built an unused stack of the neighbour
    # embeddings plus the query embedding; that dead code is removed.)
    umap_embeds = getUMAPEmbed(embeddings)

    text_news = display_news(df, nearest_ids)
    fig = plot2DChart(df, umap_embeds)

    textcontainer.write(text_news)
    plotcontainer.write(fig)
|
51 |
+
|
hausa.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c5279476f52eded50fa5254c9a6be01abe1393484eb57a8858f90c6d079e520e
|
3 |
+
size 14590027
|
igbo.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
news.ann
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:71443c486fb4dc39f3a600b705642795ad19c8ebca8e495259790e5351610b74
|
3 |
+
size 1603680
|
process.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#from dotenv import load_dotenv
|
2 |
+
from annoy import AnnoyIndex
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
import cohere
|
6 |
+
import os
|
7 |
+
import plotly.express as px
|
8 |
+
import umap
|
9 |
+
import plotly.graph_objects as go
|
10 |
+
|
11 |
+
|
12 |
+
def get_key():
    """Return the Cohere API key.

    Prefers the COHERE_API_KEY environment variable; falls back to the
    legacy hard-coded key so existing deployments keep working.
    """
    # SECURITY: this key is committed to the repository (it also appears in
    # the checked-in .env file) and should be rotated. New deployments must
    # set COHERE_API_KEY in the environment instead.
    return os.getenv("COHERE_API_KEY", "7rMjNpj7LLTNlAcoR1Sc6cH23aURrBQoMPi9vzam")
|
16 |
+
|
17 |
+
|
18 |
+
def import_ds():
    """Load the five language CSVs, subsample, merge and clean them.

    Each per-language frame is halved, the union is shuffled, rows without
    a title are dropped, duplicate titles are removed, and a final random
    sample of 500 articles is returned.

    Returns
    -------
    pandas.DataFrame with (at least) 'title', 'summary', 'url' columns.
    """
    languages = ['amharic', 'hausa', 'swahili', 'yoruba', 'igbo']

    # One read/sample pair per language (replaces five copy-pasted blocks).
    frames = [pd.read_csv(f'{lang}.csv').sample(frac=0.5) for lang in languages]

    df_news = pd.concat(frames, axis=0)

    # frac=1 shuffles the concatenated frame.
    df_news = df_news.sample(frac=1)

    df_news = df_news[df_news['title'].notna()]
    df_news = df_news.drop_duplicates("title")

    # Keep the corpus small enough to embed interactively.
    return df_news.sample(500)
|
45 |
+
|
46 |
+
|
47 |
+
def getEmbeddings(co,df):
    """Embed every article (title + summary) with Cohere's multilingual model.

    Parameters
    ----------
    co : cohere.Client
        An authenticated Cohere client.
    df : pandas.DataFrame
        Must provide 'title' and 'summary' columns.

    Returns
    -------
    numpy.ndarray of shape (len(df), embedding_dim).

    Unlike the previous version, the caller's frame is left untouched (no
    'text' column is added in place) and no 'id' column is required.
    """
    texts = (df['title'] + df['summary']).tolist()

    embeds = co.embed(texts=texts,model="multilingual-22-12",truncate="RIGHT").embeddings

    return np.array(embeds)
|
58 |
+
|
59 |
+
def semantic_search(emb,indexfile):
    """Build an angular-distance Annoy index over `emb` and save it to disk.

    Parameters
    ----------
    emb : array-like of shape (n_items, dim) — the embedding matrix.
    indexfile : path where the built index is persisted.
    """
    vectors = np.array(emb)

    index = AnnoyIndex(vectors.shape[1], 'angular')
    print(vectors.shape[1])

    for pos, vec in enumerate(vectors):
        index.add_item(pos, vec)

    # 10 trees: Annoy's accuracy/speed trade-off knob.
    index.build(10)
    index.save(indexfile)
|
71 |
+
|
72 |
+
def get_query_embed(co, query):
    """Embed a single query string; returns a (1, dim) numpy array."""
    response = co.embed(
        texts=[query],
        model='multilingual-22-12',
        truncate='right',
    )
    return np.array(response.embeddings)
|
78 |
+
|
79 |
+
def getClosestNeighbours(indexfile,query_embed,neighbours=15,dim=768):
    """Load the saved Annoy index and return the nearest neighbours.

    Parameters
    ----------
    indexfile : path of the index written by semantic_search.
    query_embed : array of shape (1, dim) as returned by get_query_embed.
    neighbours : number of results to fetch.
    dim : embedding dimensionality of the index. Defaults to 768 (the
        previously hard-coded value, matching the multilingual-22-12
        embeddings used elsewhere in this module) so existing callers
        are unaffected, while other models can now be supported.

    Returns
    -------
    ([ids...], [distances...]) tuple from Annoy's get_nns_by_vector.
    """
    search_index = AnnoyIndex(dim, 'angular')
    search_index.load(indexfile)

    # Retrieve the nearest neighbors
    similar_item_ids = search_index.get_nns_by_vector(query_embed[0],neighbours,
        include_distances=True)

    return similar_item_ids
|
90 |
+
|
91 |
+
def display_news(df,similar_item_ids):
    """Build a clean results table for the retrieved neighbour ids.

    Parameters
    ----------
    df : source frame with 'title', 'url' and 'summary' columns.
    similar_item_ids : (ids, distances) pair from getClosestNeighbours;
        only the ids are used.

    Returns
    -------
    pandas.DataFrame with a fresh 0..n-1 index.
    """
    hits = df.iloc[similar_item_ids[0]]

    results = pd.DataFrame(data={
        'title': hits['title'],
        'url': hits['url'],
        'summary': hits['summary'],
    })

    return results.reset_index(drop=True)
|
102 |
+
|
103 |
+
def getUMAPEmbed(embeds):
    """Project the embedding matrix to 2-D with UMAP for plotting."""
    # n_neighbors=20 favours local structure suited to ~500-row corpora.
    return umap.UMAP(n_neighbors=20).fit_transform(embeds)
|
108 |
+
|
109 |
+
|
110 |
+
def plot2DChart(df, umap_embeds, clusters=None):
    """Build a 2-D scatter plot of the UMAP-projected embeddings.

    Parameters
    ----------
    df : frame providing 'url' and 'title' columns (title is hover text).
    umap_embeds : array of shape (len(df), 2) from getUMAPEmbed.
    clusters : accepted for interface compatibility; currently unused
        (the previous body normalized None to {} and never read it —
        that dead code is removed).

    Returns
    -------
    plotly Figure.
    """
    df_viz = pd.DataFrame(data={'url': df['url'], 'title': df['title']})
    df_viz['x'] = umap_embeds[:, 0]
    df_viz['y'] = umap_embeds[:, 1]

    # Plot
    fig = px.scatter(df_viz, x='x', y='y', hover_data=['title'])

    # Reverse trace order so earlier traces render on top.
    fig.data = fig.data[::-1]

    return fig
|
126 |
+
|
127 |
+
if __name__ == '__main__':
    # Smoke-test the pipeline end to end, mirroring the flow in app.py.
    # The previous version crashed: it called an undefined `process()`,
    # omitted semantic_search's `indexfile` argument, and passed a
    # DataFrame where getClosestNeighbours expects an index path.
    key = get_key()
    co = cohere.Client(key)
    df_news = import_ds()
    embed = getEmbeddings(co, df_news)
    indexfile = 'news.ann'
    semantic_search(embed, indexfile)
    query_embed = get_query_embed(co, 'example news query')
    print(getClosestNeighbours(indexfile, query_embed))
|
requirements.txt
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
altair==4.2.2
|
2 |
+
annoy==1.17.0
|
3 |
+
huggingface-hub==0.14.1
|
4 |
+
numpy==1.21.6
|
5 |
+
pandas==1.3.5
|
6 |
+
plotly==5.14.1
|
7 |
+
scipy==1.7.3
|
8 |
+
beautifulsoup4==4.11.1
|
9 |
+
cohere==2.7.0
|
10 |
+
matplotlib==3.5.1
|
11 |
+
python-dotenv==0.21.0
|
12 |
+
scikit_learn==1.0.2
|
13 |
+
streamlit==1.22.0
|
14 |
+
streamlit_plotly_events==0.0.6
|
15 |
+
umap==0.1.1
|
16 |
+
umap_learn==0.5.3
|
swahili.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bccf0a9aaa7f5399fa51b6d34df9848f5a077a771ce2318f7f6beb58686dee99
|
3 |
+
size 20901981
|
utils.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datasets import load_dataset_builder, load_dataset
|
2 |
+
import logging
|
3 |
+
|
4 |
+
def inspect():
    """Probe the xlsum dataset builders and return (description, features).

    NOTE(review): desc/feat are overwritten on every loop iteration, so
    only the metadata of the final language is returned — confirm whether
    a per-language result was intended.
    """
    langs = ['amharic','english','hausa','swahili','yoruba','igbo']

    for lang in langs:
        builder = load_dataset_builder("csebuetnlp/xlsum", lang)
        desc = builder.info.description
        feat = builder.info.features

    return desc, feat
|
15 |
+
|
16 |
+
def load():
    """Download each xlsum language train split and export it to '<lang>.csv'.

    Best-effort: a failure stops the export loop but does not propagate.
    Errors were previously swallowed at DEBUG level, which hides them
    under the default logging configuration; log the full traceback
    at ERROR level instead.
    """
    try:
        langs = ['amharic','hausa','swahili','yoruba','igbo']

        for lang in langs:
            dataset = load_dataset("csebuetnlp/xlsum", lang ,split="train")
            dataset.to_csv(f"{lang}.csv", index = None)
    except Exception:
        logging.exception("Failed to export xlsum datasets to CSV")
|
29 |
+
|
30 |
+
if __name__ == '__main__':
    # print(inspect())  # optional: dump dataset metadata instead
    load()
|
yoruba.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9f5df52e87acfcd2fae999e7108a4f8c5e44345070b3c41380d72c47f8fd1412
|
3 |
+
size 16448886
|