import gradio as gr
import json
import os

import spaces
import torch
from dotenv import load_dotenv
from huggingface_hub import login, snapshot_download

from superposed.llama.superposed_generation import SuperposedLlama
from superposed.llama.tokenizer import Tokenizer
from superposed.ngrams.ngram_models import make_models
# Set torch.distributed variables for single-process execution
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_PORT"] = "12193"
os.environ["MASTER_ADDR"] = "127.0.0.1"
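# Note: the Llama reference code this repo builds on initializes
# torch.distributed / fairscale model parallelism even when running a single
# process, so these variables stand in for the values a launcher would
# normally provide (assumption based on that reference implementation).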
def load_models():
    """Build the superposed Llama model from the downloaded weights."""
    model = SuperposedLlama.build(ckpt_dir=weight_path,
                                  tokenizer_path=f"{weight_path}/tokenizer.model",
                                  max_seq_len=100,
                                  max_batch_size=32,
                                  device="cuda",
                                  model_parallel_size=1)
    return model
# load_dotenv()  # Uncomment for local runs with a .env file
login(os.getenv("HF_ACCESS_TOKEN"))
weight_path = "./weights/"
if not os.path.exists(weight_path):
    os.mkdir(weight_path)
snapshot_download(repo_id="meta-llama/Llama-2-7b", local_dir=weight_path)
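# Note: meta-llama/Llama-2-7b is a gated repository, so the token passed to
# login() must belong to an account that has accepted the Llama 2 license,
# or the download will fail.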
# Load decoding hyperparameters
param_file = "params/p15_d3_ngram4_mixed.json"
with open(param_file, "r") as f:
    params = json.load(f)
alpha = params["alpha"]
temp = params["temp"]
n_drafts = params["n_drafts"]
prompt_len = params["prompt_len"]
n_token_sample = params["n_token_sample"]
i_weights = params["i_weights"]
i_length = params["i_length"]
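# Assumed roles, following the Superposed Decoding paper: `n_drafts` is the
# number of superposed drafts carried through one forward pass, `temp` the
# softmax temperature, `n_token_sample` the number of top tokens considered
# at each step, `alpha` the interpolation weight between model and n-gram
# probabilities, and `i_weights`/`i_length` the per-order weights and orders
# of the n-gram mixture.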
# Load main model
model = load_models()
tokenizer = Tokenizer(f"{weight_path}/tokenizer.model")
# Create ngram models
ngrams = make_models("ckpts-200k", bigram=True, trigram=True, fourgram=True,
                     fivegram=False, sixgram=False, sevengram=False)
def decode(tokenizer, encoding):
    """Decode a token sequence, truncating at the first EOS token.

    Args:
        tokenizer (Any): Tokenizer
        encoding (torch.Tensor): Encoding
    Returns:
        decoding (str)
    """
    eos_locs = (encoding == tokenizer.eos_id).nonzero()
    if len(eos_locs) > 0:
        encoding = encoding[:eos_locs[0]]
    return tokenizer.decode(encoding.to(torch.int32).tolist())
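# Example: for a hypothetical encoding of [token ids..., eos_id, pad, pad],
# decode() returns only the text before the EOS marker, so padding after a
# finished draft never leaks into the displayed suggestion.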
@spaces.GPU  # Assumed: acquire a ZeroGPU worker per call; the `spaces` import and the latency disclaimer below suggest this decorator was intended
def update_options(input, num_tokens):
    tokenized_prompts = tokenizer.encode([input], True, False)
    print("Processed prompt")
    model.model.to("cuda")
    model.model.device = "cuda"
    alive_gens, _ = model.sup_generate(prompt_tokens=tokenized_prompts,
                                       smoothing="geom",
                                       max_gen_len=num_tokens,
                                       n_token_sample=n_token_sample,
                                       alpha=alpha,
                                       temp=temp,
                                       n_drafts=n_drafts,
                                       i_weights=i_weights,
                                       i_length=i_length,
                                       ngrams=ngrams,
                                       get_time=False,
                                       penalty=200)
    print("Generated")
    gens = alive_gens[0].reshape(n_drafts, -1)
    # Strip the prompt from each draft and return the three continuations
    return (decode(tokenizer, gens[0])[len(input):],
            decode(tokenizer, gens[1])[len(input):],
            decode(tokenizer, gens[2])[len(input):])
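# Indexing gens[0..2] assumes n_drafts >= 3; the parameter file name
# (p15_d3_ngram4_mixed.json) suggests d = 3 drafts, matching the three
# option buttons below.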
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Superposed Decoding
        Start typing below to see suggestions.\n
        Disclaimer: This demo only uses $n=\{2, 3, 4\}$ $n$-grams as opposed to $n=\{2, 3, 4, 5, 6\}$ in the paper. In addition, there may be significant latency at times because a GPU must be re-acquired after every change.
        \n
        Paper: [https://arxiv.org/abs/2405.18400](https://arxiv.org/abs/2405.18400)\n
        Code: [https://github.com/RAIVNLab/SuperposedDecoding](https://github.com/RAIVNLab/SuperposedDecoding)
        """)
    slider = gr.Slider(minimum=1, maximum=10, step=1, label="Generation length", value=10)
    inp = gr.Textbox(placeholder="Type anything!", lines=3)
    option1 = gr.Button(value="Option 1")
    option2 = gr.Button(value="Option 2")
    option3 = gr.Button(value="Option 3")
    inp.change(update_options, inputs=[inp, slider], outputs=[option1, option2, option3])
    # Button updates: clicking a suggestion appends its text to the prompt
    def option1_click(curr, txt):
        return curr + txt
    def option2_click(curr, txt):
        return curr + txt
    def option3_click(curr, txt):
        return curr + txt
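    # The handlers above are never registered in the original snippet; a
    # minimal wiring sketch, assuming gr.Button passes its current label text
    # when used as an input component (and each button's label holds a
    # suggestion, since update_options writes its outputs to the buttons):
    option1.click(option1_click, inputs=[inp, option1], outputs=inp)
    option2.click(option2_click, inputs=[inp, option2], outputs=inp)
    option3.click(option3_click, inputs=[inp, option3], outputs=inp)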
if __name__ == "__main__":
    demo.launch(share=True)