tingyuansen committed on
Commit
6dbd035
·
verified ·
1 Parent(s): 7e3854e

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +19 -9
README.md CHANGED
@@ -44,23 +44,33 @@ model = AutoModelForCausalLM.from_pretrained("AstroMLab/astrollama-2-7b-chat_aic
44
 
45
  # Function to generate a response
46
  def generate_response(prompt, max_length=512):
47
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_length)
 
48
  inputs = inputs.to(model.device)
49
 
50
  # Generate a response
51
  with torch.no_grad():
52
- outputs = model.generate(**inputs, max_length=max_length, num_return_sequences=1, do_sample=True)
 
 
 
 
 
 
 
53
 
54
  # Decode and return the response
55
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
56
- return response
 
 
 
57
 
58
- # Example conversation
59
  user_input = "What are the main components of a galaxy?"
60
- prompt = f"Human: {user_input}\n\nAssistant:"
61
-
62
- response = generate_response(prompt)
63
- print(response)
64
  ```
65
 
66
  ## Model Limitations and Biases
 
44
 
45
  # Function to generate a response
46
  def generate_response(prompt, max_length=512):
47
+ full_prompt = f"###Human: {prompt}\n\n###Assistant:"
48
+ inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=max_length)
49
  inputs = inputs.to(model.device)
50
 
51
  # Generate a response
52
  with torch.no_grad():
53
+ outputs = model.generate(
54
+ **inputs,
55
+ max_length=max_length,
56
+ num_return_sequences=1,
57
+ do_sample=True,
58
+ pad_token_id=tokenizer.eos_token_id,
59
+ eos_token_id=tokenizer.encode("###Human:", add_special_tokens=False)[0]
60
+ )
61
 
62
  # Decode and return the response
63
+ response = tokenizer.decode(outputs[0], skip_special_tokens=False)
64
+
65
+ # Extract only the Assistant's response
66
+ assistant_response = response.split("###Assistant:")[-1].strip()
67
+ return assistant_response
68
 
69
+ # Example usage
70
  user_input = "What are the main components of a galaxy?"
71
+ response = generate_response(user_input)
72
+ print(f"Human: {user_input}")
73
+ print(f"Assistant: {response}")
 
74
  ```
75
 
76
  ## Model Limitations and Biases