[Scripts] Add dense baseline evaluation scripts. #14

Merged · 12 commits · Apr 26, 2025
129 changes: 57 additions & 72 deletions eval/reasoning_tasks/eval.py
@@ -80,6 +80,7 @@ def parse_args():
parser.add_argument('--n_sampling', type=int, default=1, help="n for sampling")
parser.add_argument('--batch_size', type=int, default=16, help="batch_size")
parser.add_argument('--limit', type=int, default=-1, help="limit")
parser.add_argument('--repeat', type=int, default=1, help="repeat")
parser.add_argument("--k", type=int, default=1, help="Value of k for pass@k calculation")
parser.add_argument("--data_dir", default="./data", type=str)
parser.add_argument('--data_name', type=str, default="math", help='identify how to extract answer')
@@ -100,8 +101,8 @@ def parse_args():
parser.add_argument("--threshold", default=0, type=float)
parser.add_argument("--block_size", default=64, type=int)
parser.add_argument("--rank", default=0, type=int)
parser.add_argument("--attention_implementation", default="seer_sparse", choices=["seer_sparse", "oracle_sparse", "fa2", "sdpa"], type=str)
parser.add_argument("--use_batch_exist", action="store_true")
parser.add_argument("--attention_implementation", default="seer_sparse", choices=["seer_sparse", "seer_dense", "oracle_sparse", "fa2", "sdpa"], type=str)
parser.add_argument("--use_batch_exist", default=True, type=bool)
parser.add_argument("--use_fused_kernel", action="store_true")
parser.add_argument("--profile_sparsity", action="store_true")
args = parser.parse_args()
@@ -146,6 +147,8 @@ def get_three_prompt(prompt_type, data_name):
def infer(args):
model_name_or_path = args.model_name_or_path
print(f"current eval model: {model_name_or_path}")
device = f"cuda:{args.rank}"

generate_lens = []
prompt_lens = []

@@ -159,6 +162,7 @@ def infer(args):
limit = args.limit
if limit > 0:
examples = examples[:limit]
examples = examples * args.repeat

if args.profile_sparsity:
assert args.attention_implementation in ["seer_sparse", "oracle_sparse"], "profile_sparsity only supports seer_sparse and oracle_sparse"
@@ -190,7 +194,7 @@ def infer(args):
else:
cur_prompt = question_format.format(question=question)
if args.surround_with_messages:
if args.data_name in ["aime", "math"]:
if args.data_name in ["aime", "math", "olympiadbench"]:
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": cur_prompt}
@@ -204,15 +208,15 @@ def infer(args):
prompt_batch.append(cur_prompt)


if args.attention_implementation == "seer_sparse" or args.attention_implementation == "oracle_sparse":
if args.attention_implementation == "seer_sparse" or args.attention_implementation == "oracle_sparse" or args.attention_implementation == "seer_dense":
model = SeerDecodingQwen2ForCausalLM.from_pretrained(model_name_or_path,
torch_dtype=torch.bfloat16,
device_map="auto",
device_map=device,
load_gate = args.attention_implementation == "seer_sparse",
use_cache=True,
seerattn_threshold=args.threshold,
seerattn_gate_block_size=args.block_size,
seerattn_use_oracle_sparse = args.attention_implementation == "oracle_sparse",
seerattn_implementation = args.attention_implementation,
use_flash_rope=args.use_fused_kernel,
fused_norm=args.use_fused_kernel,
seerattn_output_sparsity=args.profile_sparsity,
@@ -221,14 +225,14 @@ def infer(args):
elif args.attention_implementation == "fa2":
model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
torch_dtype=torch.bfloat16,
device_map="auto",
device_map=device,
use_cache=True,
attn_implementation="flash_attention_2",
trust_remote_code=True)
elif args.attention_implementation == "sdpa":
model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
torch_dtype=torch.bfloat16,
device_map="auto",
device_map=f"cuda:{args.rank}",
use_cache=True,
trust_remote_code=True)
else:
@@ -239,13 +243,15 @@ def infer(args):

generate_lens = []
correct_cnt = 0
output_subdir = f"{args.data_name}_bs_{args.batch_size}_attn_{args.attention_implementation}_T{args.threshold}_blocksize{args.block_size}_batch_exist_{args.use_batch_exist}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
checkpoint_filename = f"ckpt.jsonl"
args.output_dir = os.path.join(args.output_dir, output_subdir)
os.makedirs(args.output_dir, exist_ok=True)
output_path_txt = os.path.join(args.output_dir, "summary.txt")
output_completions_path = os.path.join(args.output_dir, "completions.json")
checkpoint_filename_json = os.path.join(args.output_dir, checkpoint_filename)
output_filename = f"{args.data_name}_bs{args.batch_size}_{args.attention_implementation}_T{args.threshold}_blocksize{args.block_size}_batchexist{args.use_batch_exist}.txt"
output_dir = os.path.join(args.output_dir, f"rank{args.rank}")
os.makedirs(output_dir, exist_ok=True)
print("make output dir: ", output_dir)
output_path_txt = os.path.join(output_dir, output_filename)
completion_filename = output_filename[:-4] + "_completions.json"
sparsity_filename = output_filename[:-4] + "_sparsity_info.json"
ckpt_filename = output_filename[:-4] + "_ckpt.json"

completions = []
batch_size = args.batch_size
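Note: the timestamped output subdirectory from the old code above is replaced by a flat per-rank layout. With the defaults in eval_dense.sh below, outputs would land roughly as follows (an illustration reconstructed from the filename format string, not output copied from a run):

result_dense/
  rank0/
    olympiadbench_bs60_seer_dense_T0_blocksize64_batchexistTrue.txt
    olympiadbench_bs60_seer_dense_T0_blocksize64_batchexistTrue_completions.json
    olympiadbench_bs60_seer_dense_T0_blocksize64_batchexistTrue_ckpt.json
  rank1/
  ...

The *_sparsity_info.json file appears only when --profile_sparsity is set.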

@@ -257,7 +263,7 @@ def infer(args):
# Tokenize the prompt batch
print("start batch: ", i, flush=True)
batch_prompts = prompt_batch[i:i+batch_size]
tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens=True).to('cuda')
tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens=True).to(device)
batch_input_ids = tokenized_prompts.input_ids
attention_mask = tokenized_prompts.attention_mask

@@ -291,80 +297,59 @@ def infer(args):

print("get output in batch: ", i, flush=True)

all_batch_sparsitys_info.append(batch_sparsitys_info)
if args.profile_sparsity:
all_batch_sparsitys_info.append(batch_sparsitys_info)

for j in range(len(outputs)):
output_seq = outputs[j]
num_tokens = (output_seq != tokenizer.pad_token_id).sum().item()
generate_lens.append(num_tokens - len(batch_input_ids[j]))
output_tokens = (output_seq != tokenizer.pad_token_id).sum().item()
prompt_tokens = (batch_input_ids[j] != tokenizer.pad_token_id).sum().item()
generate_lens.append(output_tokens - prompt_tokens)

batch_results = tokenizer.batch_decode(outputs, skip_special_tokens=True)
completions.extend(batch_results)
print("finish batch: ", i, flush=True)

if args.profile_sparsity:
total_activate_count, total_original_count, overall_sparsity_ratio = calculate_overall_sparsity(all_batch_sparsitys_info)
print("total_activate_count: ", total_activate_count)
print("total_original_count: ", total_original_count)
print("overall_sparsity: ", overall_sparsity_ratio)

# check all the correct
for i in range(len(prompt_batch)):
d = examples[i]
gt_cot, gt_ans = parse_ground_truth(d, args.data_name)
generated_responses = [completions[i]]
generated_answers = [extract_answer(generated_response, args.data_name) for generated_response in generated_responses]
is_correct_list = [check_is_correct(generated_answer, gt_ans) for generated_answer in generated_answers]
is_correct = any(is_correct_list)
if is_correct:
correct_cnt += 1



end = time.time()
total_time = end - begin
print("llm generate done")
if os.path.exists(checkpoint_filename_json):
os.remove(checkpoint_filename_json)

print("generate_lens: ", generate_lens)
if args.profile_sparsity:
total_activate_count, total_original_count, overall_sparsity_ratio = calculate_overall_sparsity(all_batch_sparsitys_info)
print("Overall_sparsity: ", overall_sparsity_ratio)


print(f"correct cnt / total cnt: {correct_cnt}/{len(examples)}")
print(f"Acc: {correct_cnt / len(examples):.4f}")

with open(os.path.join(output_dir, completion_filename), 'w') as f:
json.dump(completions, f)

# generate_len
average_generate_len = sum(generate_lens) / len(generate_lens)
max_generate_len = max(generate_lens)
print(f"Max generate length: {max_generate_len}")
print(f"Average generate length: {average_generate_len}")
if args.profile_sparsity:
sparsity_info = {"sparsity_info": all_batch_sparsitys_info}
with open(os.path.join(output_dir, sparsity_filename), 'w') as f:
json.dump(sparsity_info, f)

end = time.time()
total_time = end - begin
average_time_per_token = total_time / sum(generate_lens)
print(f"Total time: {total_time}s")
print(f"Average time per token: {average_time_per_token}")
if args.profile_sparsity:
checkpoint_data = {
"output_path_txt": output_path_txt,
"generate_lens": generate_lens,
"total_time": total_time,
"overall_sparsity": overall_sparsity_ratio,
}
else:
checkpoint_data = {
"output_path_txt": output_path_txt,
"generate_lens": generate_lens,
"total_time": total_time,
}

with open(os.path.join(output_dir, ckpt_filename), 'w') as f:
json.dump(checkpoint_data, f)

with open(output_path_txt, "a") as f:
f.write(f"Acc: {correct_cnt / len(examples):.4f}\n")
f.write(f"Average generate length: {average_generate_len}\n")
f.write(f"Max generate length: {max_generate_len}\n")
f.write(f"Total time: {total_time/60:.2f}min\n")
f.write(f"Average time per token: {average_time_per_token}\n")
if args.profile_sparsity:
f.write(f"Total activate count: {total_activate_count}\n")
f.write(f"Total original count: {total_original_count}\n")
f.write(f"Overall sparsity: {overall_sparsity_ratio}\n")
f.write("\n")
print("Successfully saved!")

print("Results saved to ", output_path_txt)


# Save completions to json
with open(output_completions_path, "w") as f:
json.dump(completions, f)


if __name__ == "__main__":
os.environ["TOKENIZERS_PARALLELISM"] = "false"
args = parse_args()
set_seed(args.seed)
infer(args)
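A note on the generate-length fix in the loop above: the old code subtracted the padded prompt width, len(batch_input_ids[j]), which under-counts generation whenever prompts are padded to the longest prompt in the batch. A minimal standalone sketch of the corrected arithmetic; the pad id and token values here are illustrative, not values from the script:

import torch

pad_id = 0
# Prompt padded to the batch width of 5; the real prompt is 3 tokens.
prompt = torch.tensor([pad_id, pad_id, 7, 8, 9])
# Output sequence: the padded prompt followed by 4 generated tokens.
output = torch.tensor([pad_id, pad_id, 7, 8, 9, 11, 12, 13, 14])

output_tokens = (output != pad_id).sum().item()  # 7 non-pad tokens
prompt_tokens = (prompt != pad_id).sum().item()  # 3 non-pad tokens
print(output_tokens - prompt_tokens)             # 4, the true generated length
# The old output_tokens - len(prompt) would report 7 - 5 = 2.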
40 changes: 40 additions & 0 deletions eval/reasoning_tasks/eval_dense.sh
@@ -0,0 +1,40 @@
model_dir="deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"
output_dir="./result_dense"
task=olympiadbench # aime math gpqa olympiadbench

bs=60
limit=-1
repeat=1

use_batch_exist=1
attention_implementation=seer_dense # fa2 seer_sparse seer_dense

for gpu in 0 1 2 3
do
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
python eval.py \
--model_name_or_path $model_dir \
--data_name $task \
--batch_size $bs \
--limit $limit \
--repeat $repeat \
--output_dir $output_dir \
--attention_implementation $attention_implementation \
--use_batch_exist $use_batch_exist \
--surround_with_messages \
--rank $gpu &
done
wait

echo "All generation finished"

for gpu in 0 1 2 3
do
python get_results.py \
--data_name $task \
--limit $limit \
--repeat $repeat \
--output_dir ${output_dir}/rank${gpu}
done

echo "All finished"
10 changes: 8 additions & 2 deletions eval/reasoning_tasks/generation_utils.py
@@ -20,7 +20,10 @@ def batch_exist_generate(
# Initialize variables
generation_config, model_kwargs = model._prepare_generation_config(None)
generated = input_ids
eos_token_id = generation_config.eos_token_id
if isinstance(generation_config.eos_token_id, list):
    eos_token_id = generation_config.eos_token_id[0]
else:
    eos_token_id = generation_config.eos_token_id
initial_batch_size = input_ids.shape[0]

device = input_ids.device
@@ -70,7 +73,10 @@


# Update finished flags for the active sequences.
finished[cur_to_orig] |= (next_tokens.squeeze(1) == eos_token_id)
if isinstance(generation_config.eos_token_id, list):
    # `tensor in list` raises for batch size > 1; compare against all EOS ids instead
    eos_ids = torch.tensor(generation_config.eos_token_id, device=next_tokens.device)
    finished[cur_to_orig] |= torch.isin(next_tokens.squeeze(1), eos_ids)
else:
    finished[cur_to_orig] |= (next_tokens.squeeze(1) == eos_token_id)

# If all sequences are finished, break.
if finished.all():
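For context on the hunk above: batch_exist_generate compacts the batch as sequences hit EOS, and cur_to_orig maps each still-active row back to its original batch index so the finished flags land on the right sequences. A minimal standalone sketch of that bookkeeping, with toy tensors and variable roles inferred from this diff:

import torch

finished = torch.tensor([False, True, False, True])  # per-original-sequence flags
cur_to_orig = torch.arange(4)                        # active row -> original index
keep = ~finished[cur_to_orig]                        # rows still generating
cur_to_orig = cur_to_orig[keep]                      # tensor([0, 2]) after compaction
# Input ids, attention masks, and KV caches are index-selected with keep as well,
# so subsequent decode steps run only on unfinished sequences.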