
Commit 2db2b1d

feat: add C-Eval
1 parent 5d6710f commit 2db2b1d

File tree

6 files changed: +473 -0 lines changed

eval/benchs/__init__.py (+3 lines)

```diff
@@ -1,4 +1,5 @@
 from .base_evaluator import BaseEvaluator
+from .ceval.eval_ceval import CEvalEvaluator
 from .exampleqa.eval_exampleqa import ExampleQAEvaluator
 from .halluqa.eval_halluqa_mc import HalluQAMCEvaluator
 from .halueval.eval_halueval_dialog import HaluEvalDialogEvaluator
@@ -11,6 +12,8 @@
 
 # ! Register all evaluators here in alphabetical order.
 __all__ = [
+    # CEval
+    "CEvalEvaluator",
     # ExampleQA
     "ExampleQAEvaluator",
     # HalluQA
```
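With `CEvalEvaluator` registered in `__all__`, callers can import it from the package's public API rather than from its module path. A minimal sketch, assuming the repository root is on `sys.path` (the `eval.benchs` package path is inferred from the file tree):

```python
# Hypothetical import through the package registry edited above.
from eval.benchs import CEvalEvaluator
```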

eval/benchs/ceval/README.md (new file, +34 lines)

# C-Eval

## Information

- **Paper**: C-Eval: A Multi-Level Multi-Discipline Chinese Evaluation Suite for Foundation Models
- **Institution**:
  - Shanghai Jiao Tong University
  - Tsinghua University
  - University of Edinburgh
  - Hong Kong University of Science and Technology
- **arXiv**: https://arxiv.org/abs/2305.08322
- **GitHub**: https://github.com/hkust-nlp/ceval
- **Website**: https://cevalbenchmark.com/

## Evaluators

| Evaluator        | Metric   | Description       |
| ---------------- | -------- | ----------------- |
| `CEvalEvaluator` | Accuracy | Multi-choice task |

## Note

Make sure you can **access Hugging Face** so that the dataset can be downloaded.

## Citation

```bibtex
@inproceedings{huang2023ceval,
    title={C-Eval: A Multi-Level Multi-Discipline Chinese Evaluation Suite for Foundation Models},
    author={Huang, Yuzhen and Bai, Yuzhuo and Zhu, Zhihao and Zhang, Junlei and Zhang, Jinghan and Su, Tangjun and Liu, Junteng and Lv, Chuancheng and Zhang, Yikai and Lei, Jiayi and Fu, Yao and Sun, Maosong and He, Junxian},
    booktitle={Advances in Neural Information Processing Systems},
    year={2023}
}
```
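The Hugging Face requirement in the README's Note comes from the loader in `eval/benchs/ceval/dataset.py` below, which fetches every requested discipline from the `ceval/ceval-exam` dataset on the Hub. A quick connectivity check (illustrative only, not part of the commit) is to pull a single discipline manually:

```python
from datasets import load_dataset

# Same call CEvalDataset makes for each discipline; "college_physics" is one
# of the C-Eval subject configs used in this commit.
ds = load_dataset("ceval/ceval-exam", "college_physics", split="val")
print(len(ds), ds[0]["question"])
```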

eval/benchs/ceval/dataset.py (new file, +40 lines)

```python
from collections import defaultdict
from typing import Literal

from datasets import load_dataset
from tqdm import tqdm

from ..base_dataset import BaseDataset
from .utils import get_subject_mapping


class CEvalDataset(BaseDataset):
    def __init__(
        self, disciplines: set[str] = None, split: Literal["test", "val", "dev"] = "val"
    ):
        """
        Args:
            disciplines: Disciplines to load. If None, all disciplines will be loaded.
            split: The split to load. One of "test", "val", "dev".
        """
        subject_mapping = get_subject_mapping()
        self.data = []
        if disciplines is None:
            disciplines = set(subject_mapping.keys())

        for discipline in tqdm(disciplines, desc=f"Loading CEval > {split}"):
            ds = load_dataset("ceval/ceval-exam", discipline, split=split)
            for item in ds:
                item["id"] = f"{discipline}_{split}_{item['id']:>04}"
                item["type"] = discipline
                self.data.append(item)

    def load(self) -> list[dict]:
        return self.data

    def load_as_dict_of_discipline(self, num_shots: int) -> dict[str, list[dict]]:
        examples = defaultdict(list)
        for item in self.data:
            if len(examples[item["type"]]) < num_shots:
                examples[item["type"]].append(item)
        return examples
```
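As a usage sketch (not part of the commit, and assuming the repo's package context so the relative imports resolve), the class can be exercised on its own; discipline names are the Hugging Face config names, e.g. those listed in `CEVAL_HARD_DISCIPLINES` in the evaluator below:

```python
# Few-shot examples are drawn from the "dev" split, as CEvalEvaluator does.
dev = CEvalDataset({"college_physics", "high_school_chemistry"}, split="dev")
shots = dev.load_as_dict_of_discipline(num_shots=2)  # {"college_physics": [item, item], ...}

# Test questions come from the "val" split by default.
val = CEvalDataset({"college_physics", "high_school_chemistry"}, split="val")
items = val.load()  # flat list of dicts with id, type, question, A-D, answer
print(len(items), {k: len(v) for k, v in shots.items()})
```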

eval/benchs/ceval/eval_ceval.py (new file, +124 lines)

```python
from typing import Literal

from ...llms import BaseLLM
from ..base_evaluator import BaseEvaluator
from .dataset import CEvalDataset
from .utils import get_subject_mapping

QA_TEMPLATE = """
{question}
A. {choice_a}
B. {choice_b}
C. {choice_c}
D. {choice_d}
答案:{answer}
"""

PROMPT_TEMPLATE = """以下是中国关于{discipline}考试的单项选择题,请选出其中的正确答案。
{qa_examples}
{qa_test}"""


CEVAL_HARD_DISCIPLINES = ",".join(
    [
        "advanced_mathematics",
        "discrete_mathematics",
        "probability_and_statistics",
        "college_chemistry",
        "college_physics",
        "high_school_mathematics",
        "high_school_chemistry",
        "high_school_physics",
    ]
)


class CEvalEvaluator(BaseEvaluator):

    def __init__(
        self,
        model: BaseLLM,
        num_batches: int = 1,
        output_dir: str = "./output",
        disciplines: str = CEVAL_HARD_DISCIPLINES,
        split: Literal["test", "val", "dev"] = "val",
        num_shots: int = 2,
    ):
        super().__init__(
            model,
            num_batches,
            output_dir,
            disciplines=disciplines,
            split=split,
            num_shots=num_shots,
        )

        self.split = split

        # ─── Get Valid Disciplines ────────────────────────────────────

        self.all_disciplines = set(get_subject_mapping().keys())
        if disciplines is None:
            self.disciplines = self.all_disciplines
        else:
            self.disciplines = set(disciplines.split(",")) & self.all_disciplines

        # ─── Load Examples For Few-shot Learning ──────────────────────

        if num_shots > 0:
            ds = CEvalDataset(self.disciplines, split="dev")
            self.discipline_examples = ds.load_as_dict_of_discipline(num_shots)
        else:
            self.discipline_examples = {}

    def set_generation_configs(self) -> None:
        new_configs = {"max_new_tokens": 16, "do_sample": False}
        self.model.update_generation_configs(new_configs)

    def load_batched_dataset(self) -> list[list[dict]]:
        dataset = CEvalDataset(self.disciplines, split=self.split)
        batches = dataset.to_batched(self.num_batches)
        return batches

    def qa_prompt(self, examples: list[dict]) -> str:
        prompt = "".join(
            QA_TEMPLATE.format(
                question=example["question"],
                choice_a=example["A"],
                choice_b=example["B"],
                choice_c=example["C"],
                choice_d=example["D"],
                answer=example["answer"],
            )
            for example in examples
        )
        return prompt

    def scoring(self, data_point: dict) -> dict:
        discipline = data_point["type"]
        query = PROMPT_TEMPLATE.format(
            discipline=get_subject_mapping()[discipline][1],  # Get the Chinese name
            qa_examples=self.qa_prompt(self.discipline_examples[discipline]),
            qa_test=self.qa_prompt([data_point]),
        )
        query = query.strip()[:-1]  # Remove the answer to be predicted
        response = self.model.safe_request(query)
        answer = response.strip().split("\n")[0].strip()  # Get the first line
        return {
            "metrics": {
                "correct": answer == data_point["answer"],
            },
            "log": {
                "answer": answer,
                "response": response,
                "query": query,
            },
            "valid": answer != "",
        }

    def compute_overall(self, results: list[dict]) -> dict:
        return {
            "accuracy": sum([result["metrics"]["correct"] for result in results])
            / len(results),
            "num": len(results),
        }
```
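For orientation, a hedged driver sketch for the evaluator above. Only methods defined in this diff are called; `SomeLLM` is a placeholder for whichever concrete `BaseLLM` implementation the repo provides, and the full evaluation loop inherited from `BaseEvaluator` is not part of this commit:

```python
# Hypothetical end-to-end sketch; SomeLLM is a stand-in, not a real class here.
model = SomeLLM()

evaluator = CEvalEvaluator(
    model,
    disciplines="college_physics,high_school_chemistry",  # comma-separated, like CEVAL_HARD_DISCIPLINES
    split="val",
    num_shots=2,
)
evaluator.set_generation_configs()  # short greedy decoding: the model only needs to emit an answer letter

batches = evaluator.load_batched_dataset()
result = evaluator.scoring(batches[0][0])      # per-item metrics, logs, and validity
overall = evaluator.compute_overall([result])  # {"accuracy": ..., "num": 1}
print(overall)
```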
