Jintonic92 committed on
Commit 4be3af4 · verified · 1 Parent(s): 97ccdbd

Update src/ThirdModule/module3.py

Files changed (1)
  1. src/ThirdModule/module3.py +70 -100
src/ThirdModule/module3.py CHANGED
@@ -1,121 +1,91 @@
  # module3.py
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer
- from typing import Tuple
+ import requests
+ from typing import Optional
  import logging
- from config import Llama3_8b_PATH
- import re
+ from dotenv import load_dotenv
+ import os

- logger = logging.getLogger(__name__)
+ # Set up logging
  logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)

- class SelfConsistencyChecker:
-     def __init__(self, model_name: str = 'meta-llama/Meta-Llama-3-8B-Instruct'):
-         self._load_model(model_name)
+ # Load the .env file
+ load_dotenv()
+
+ # Hugging Face API settings
+ API_URL = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct"
+ API_KEY = os.getenv("HUGGINGFACE_API_KEY")
+
+ if not API_KEY:
+     raise ValueError("API_KEY is not set. Check the .env file.")
+
+ class AnswerVerifier:
+     def verify_answer(self, question: str, choices: dict) -> Optional[str]:
+         """Verify the answer based on the given question and choices."""
+         try:
+             prompt = self._create_prompt(question, choices)
+             headers = {"Authorization": f"Bearer {API_KEY}"}
+
+             response = requests.post(
+                 API_URL,
+                 headers=headers,
+                 json={"inputs": prompt}
+             )
+             response.raise_for_status()
+
+             response_data = response.json()
+             logger.debug(f"Raw API response: {response_data}")
+
+             # Handle the API response
+             generated_text = ""
+             if isinstance(response_data, list):
+                 if response_data and isinstance(response_data[0], dict):
+                     generated_text = response_data[0].get('generated_text', '')
+                 else:
+                     generated_text = response_data[0] if response_data else ''
+             elif isinstance(response_data, dict):
+                 generated_text = response_data.get('generated_text', '')
+             else:
+                 generated_text = str(response_data)
+
+             verified_answer = self._extract_answer(generated_text)
+             logger.info(f"Verified answer: {verified_answer}")
+             return verified_answer

-     def _load_model(self, model_name: str):
-         """Load the language model for self-consistency checking."""
-         logger.info(f"Loading model '{model_name}' from '{Llama3_8b_PATH}' for self-consistency check...")
-         self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=Llama3_8b_PATH, trust_remote_code=True)
-         self.model = AutoModelForCausalLM.from_pretrained(
-             model_name,
-             cache_dir=Llama3_8b_PATH,
-             torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
-             trust_remote_code=True,
-             device_map="auto"
-         )
-         self.model.eval()
-         if torch.cuda.is_available():
-             self.model.to('cuda')
-             logger.info("Model loaded on GPU for self-consistency.")
-         else:
-             logger.info("Model loaded on CPU for self-consistency.")
+         except Exception as e:
+             logger.error(f"Error in verify_answer: {e}")
+             return None

      def _create_prompt(self, question: str, choices: dict) -> str:
-         """
-         Create a prompt following the Llama 3 prompt template.
-         """
-         prompt = f"""
+         """Create the prompt used for verification."""
+         return f"""
  <|begin_of_text|>
  <|start_header_id|>system<|end_header_id|>
- You are an expert reasoning assistant. Your task is to determine the single most accurate answer (A, B, C, or D) for a multiple-choice question based on the given options.
-
- Rules:
- 1. Carefully read the question and all options.
- 2. Use logical reasoning to select the best answer.
- 3. Output your answer strictly in the following format: "Answer: [A/B/C/D]"
- 4. Do not provide any explanation or extra information.
-
+ You are an expert mathematics teacher checking student answers.
+ Please analyze the following question and select the single best answer.
+ Output ONLY the letter of the correct answer (A, B, C, or D) without any explanation.
  <|eot_id|>
  <|start_header_id|>user<|end_header_id|>
  Question: {question}

- Choices:
  A) {choices['A']}
  B) {choices['B']}
  C) {choices['C']}
  D) {choices['D']}

- Please select the correct answer.
+ Select the correct answer letter (A, B, C, or D):
  <|eot_id|>
  <|start_header_id|>assistant<|end_header_id|>
- """
-         return prompt.strip()
-
-     def _extract_answer(self, text: str) -> str:
-         """
-         Extract the answer (A, B, C, or D) from the generated text.
-         """
-         match = re.search(r"Answer:\s*([ABCD])", text, re.IGNORECASE)
-         if match:
-             answer = match.group(1).upper()
-             logger.info(f"Extracted answer: {answer} from text: {text}")
-             return answer
-         logger.warning(f"Failed to extract answer from text: {text}")
-         return ""
-
-     def check_answer(self, question: str, choices: dict, num_inferences: int = 10) -> Tuple[str, str]:
-         """
-         Perform self-consistency check:
-         - Run inference num_inferences times.
-         - Extract answer each time.
-         - Majority vote the final answer.
-         """
-
-         prompt = self._create_prompt(question, choices)  # Build the prompt
-         answer_counts = {"A": 0, "B": 0, "C": 0, "D": 0}
-
-         inputs = self.tokenizer(prompt, return_tensors='pt')
-         if torch.cuda.is_available():
-             inputs = {k: v.to('cuda') for k, v in inputs.items()}
-
-         for _ in range(num_inferences):
-             with torch.no_grad():
-                 outputs = self.model.generate(
-                     **inputs,
-                     max_new_tokens=50,
-                     num_return_sequences=1,
-                     temperature=0.7,
-                     top_p=0.9,
-                     do_sample=True,
-                     eos_token_id=self.tokenizer.eos_token_id
-                 )
-             generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-             predicted_answer = self._extract_answer(generated_text)
-
-             logger.info(f"Generated text: {generated_text}")  # Inspect the model's generated text
-             logger.info(f"Predicted answer: {predicted_answer}")  # Inspect the extracted answer
-
-             if predicted_answer in answer_counts:
-                 answer_counts[predicted_answer] += 1
-             else:
-                 logger.warning(f"Invalid answer extracted: {predicted_answer}")
-
-         # Majority vote
-         final_answer = max(answer_counts, key=answer_counts.get)
-         explanation = f"Answer counts: {answer_counts}. Majority answer: {final_answer}"
-
-         logger.info(f"Answer counts: {answer_counts}")
-         logger.info(f"Final Answer: {final_answer}")
-
-         return final_answer, explanation
+ """.strip()
+
+     def _extract_answer(self, response: str) -> Optional[str]:
+         """Extract one of A, B, C, or D from the response."""
+         response = response.strip().upper()
+         valid_answers = {'A', 'B', 'C', 'D'}
+
+         # Look for a valid answer letter in the response
+         for answer in valid_answers:
+             if answer in response:
+                 return answer
+
+         return None
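
A minimal usage sketch of the updated AnswerVerifier (not part of the commit): the import path, question, choices, and .env contents below are illustrative assumptions, and it assumes HUGGINGFACE_API_KEY is set before the module is imported.

    # .env (assumed): HUGGINGFACE_API_KEY=hf_...
    # Import path assumed; adjust to how src/ThirdModule is packaged.
    from src.ThirdModule.module3 import AnswerVerifier

    verifier = AnswerVerifier()
    answer = verifier.verify_answer(
        question="What is 7 x 8?",
        choices={"A": "54", "B": "56", "C": "64", "D": "48"},
    )
    # Prints the extracted letter (A-D), or None if the API call fails
    # or no letter can be found in the generated text.
    print(answer)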