Update README.md
README.md CHANGED
@@ -23,30 +23,75 @@ Took **28 hours** to finetune on **2x Nvidia RTX A6000** with the following sett
Run the model:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, BitsAndBytesConfig
import bitsandbytes

quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True
)

model_id = "CreitinGameplays/Llama-3.1-8B-R1-v0.1"

# Initialize model and tokenizer with streaming support
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Custom streamer that collects the output into a string while streaming
class CollectingStreamer(TextStreamer):
    def __init__(self, tokenizer):
        super().__init__(tokenizer, skip_prompt=True)  # skip the prompt so only generated text is collected
        self.output = ""

    def on_finalized_text(self, text: str, stream_end: bool = False):
        self.output += text
        print(text, end="", flush=True)  # prints the text as it's generated

print("Chat session started. Type 'exit' to quit.\n")

# Initialize chat history as a list of messages
chat_history = []
chat_history.append({"role": "system", "content": "You are an AI assistant made by Meta AI."})

while True:
    user_input = input("You: ")
    if user_input.strip().lower() == "exit":
        break

    # Append the user message to the chat history
    chat_history.append({"role": "user", "content": user_input})

    # Prepare the prompt by formatting the complete chat history
    inputs = tokenizer.apply_chat_template(
        chat_history,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    # Create a new streamer for the current generation
    streamer = CollectingStreamer(tokenizer)

    # Generate streamed response
    model.generate(
        inputs,
        streamer=streamer,
        temperature=0.6,
        top_p=0.9,
        top_k=50,
        repetition_penalty=1.1,
        max_new_tokens=6112,
        do_sample=True
    )

    # The complete response text is stored in streamer.output
    response_text = streamer.output
    print("\nAssistant:", response_text)

    # Append the assistant response to the chat history
    chat_history.append({"role": "assistant", "content": response_text})
```
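The 8-bit `BitsAndBytesConfig` above is one way to fit the 8B model on a single GPU. If VRAM is still tight, a 4-bit NF4 configuration from the same `transformers`/`bitsandbytes` API can be swapped in. The sketch below is only an illustrative alternative (the specific 4-bit settings are assumptions, not settings recommended by the model author); the rest of the chat loop works unchanged with the resulting `model` object.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Illustrative 4-bit NF4 configuration; tune these options for your hardware.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model = AutoModelForCausalLM.from_pretrained(
    "CreitinGameplays/Llama-3.1-8B-R1-v0.1",
    device_map="auto",
    quantization_config=quantization_config
)
```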
### Current Limitations