[Fix] Deprecate val_batch_size #353
Merged (3 commits) on Feb 24, 2025
3 changes: 0 additions & 3 deletions docs/examples/config.rst
@@ -19,7 +19,6 @@ Data
 max_prompt_length: 512
 max_response_length: 512
 train_batch_size: 1024
-val_batch_size: 1312
 return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs
 return_raw_chat: False
@@ -39,8 +38,6 @@ Data
 algorithms (e.g. PPO) generates up to this length
 - ``data.train_batch_size``: Batch size sampled for one training
 iteration of different RL algorithms.
-- ``data.val_batch_size``: Batch size sampled for one validation
-iteration.
 - ``data.return_raw_input_ids``: Whether to return the original
 input_ids without adding chat template. This is mainly used to
 accommodate situations where the reward model's chat template differs
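For orientation, here is a minimal sketch (illustrative only, not code from this PR) of the documented ``data`` block after the change, built with OmegaConf, together with the deprecation check the trainer now performs on it:

    # Illustrative sketch: the documented `data` block with val_batch_size deprecated.
    from omegaconf import OmegaConf

    data_cfg = OmegaConf.create({
        "max_prompt_length": 512,
        "max_response_length": 512,
        "train_batch_size": 1024,
        "val_batch_size": None,  # deprecated; validation now uses the whole dataset
        "return_raw_input_ids": False,
        "return_raw_chat": False,
    })

    # Mirrors the warning this PR adds to ray_trainer.py:
    if data_cfg.get("val_batch_size", None) is not None:
        print("WARNING: val_batch_size is deprecated. Validation datasets are "
              "sent to inference engines as a whole batch.")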
1 change: 0 additions & 1 deletion docs/examples/gsm8k_example.rst
@@ -130,7 +130,6 @@ The script of run_deepseek7b_llm.sh
 data.train_files=$HOME/data/gsm8k/train.parquet \
 data.val_files=$HOME/data/gsm8k/test.parquet \
 data.train_batch_size=1024 \
-data.val_batch_size=1312 \
 data.max_prompt_length=512 \
 data.max_response_length=512 \
 actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \
1 change: 0 additions & 1 deletion docs/start/quickstart.rst
@@ -85,7 +85,6 @@ Set the ``data.train_files`` ,\ ``data.val_files``, ``actor_rollout_ref.model.pa
 data.train_files=$HOME/data/gsm8k/train.parquet \
 data.val_files=$HOME/data/gsm8k/test.parquet \
 data.train_batch_size=256 \
-data.val_batch_size=1312 \
 data.max_prompt_length=512 \
 data.max_response_length=256 \
 actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
1 change: 0 additions & 1 deletion examples/grpo_trainer/run_deepseek7b_llm.sh
@@ -5,7 +5,6 @@ python3 -m verl.trainer.main_ppo \
 data.train_files=$HOME/data/gsm8k/train.parquet \
 data.val_files=$HOME/data/gsm8k/test.parquet \
 data.train_batch_size=1024 \
-data.val_batch_size=1312 \
 data.max_prompt_length=512 \
 data.max_response_length=1024 \
 actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \
1 change: 0 additions & 1 deletion examples/grpo_trainer/run_deepseek7b_llm_seq_balance.sh
@@ -5,7 +5,6 @@ python3 -m verl.trainer.main_ppo \
 data.train_files=$HOME/data/gsm8k/train.parquet \
 data.val_files=$HOME/data/gsm8k/test.parquet \
 data.train_batch_size=1024 \
-data.val_batch_size=1312 \
 data.max_prompt_length=512 \
 data.max_response_length=512 \
 actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \
1 change: 0 additions & 1 deletion examples/grpo_trainer/run_qwen2-7b.sh
@@ -7,7 +7,6 @@ python3 -m verl.trainer.main_ppo \
 data.train_files=$HOME/data/gsm8k/train.parquet \
 data.val_files=$HOME/data/gsm8k/test.parquet \
 data.train_batch_size=1024 \
-data.val_batch_size=1312 \
 data.max_prompt_length=512 \
 data.max_response_length=1024 \
 actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
1 change: 0 additions & 1 deletion examples/grpo_trainer/run_qwen2-7b_seq_balance.sh
@@ -7,7 +7,6 @@ python3 -m verl.trainer.main_ppo \
 data.train_files=$HOME/data/gsm8k/train.parquet \
 data.val_files=$HOME/data/gsm8k/test.parquet \
 data.train_batch_size=1024 \
-data.val_batch_size=1312 \
 data.max_prompt_length=512 \
 data.max_response_length=1024 \
 actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
1 change: 0 additions & 1 deletion examples/ppo_trainer/run_deepseek7b_llm.sh
@@ -4,7 +4,6 @@ python3 -m verl.trainer.main_ppo \
 data.train_files=$HOME/data/gsm8k/train.parquet \
 data.val_files=$HOME/data/gsm8k/test.parquet \
 data.train_batch_size=1024 \
-data.val_batch_size=1312 \
 data.max_prompt_length=512 \
 data.max_response_length=512 \
 actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \
1 change: 0 additions & 1 deletion examples/ppo_trainer/run_deepseek7b_llm_sp2.sh
@@ -4,7 +4,6 @@ python3 -m verl.trainer.main_ppo \
 data.train_files=$HOME/data/gsm8k/train.parquet \
 data.val_files=$HOME/data/gsm8k/test.parquet \
 data.train_batch_size=1024 \
-data.val_batch_size=1312 \
 data.max_prompt_length=512 \
 data.max_response_length=512 \
 actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \
1 change: 0 additions & 1 deletion examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh
@@ -7,7 +7,6 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat
 data.train_files="$train_files" \
 data.val_files="$test_files" \
 data.train_batch_size=512 \
-data.val_batch_size=128 \
 data.max_prompt_length=128 \
 data.max_response_length=128 \
 actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \
1 change: 0 additions & 1 deletion examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh
@@ -12,7 +12,6 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat
 data.train_files="$train_files" \
 data.val_files="$test_files" \
 data.train_batch_size=1024 \
-data.val_batch_size=6312 \
 data.max_prompt_length=1024 \
 data.max_response_length=512 \
 actor_rollout_ref.model.path=deepseek-ai/deepseek-coder-6.7b-instruct \
1 change: 0 additions & 1 deletion examples/ppo_trainer/run_deepseek_megatron.sh
@@ -13,7 +13,6 @@ python3 -m verl.trainer.main_ppo --config-path=config \
 data.train_files=$HOME/data/gsm8k/train.parquet \
 data.val_files=$HOME/data/gsm8k/test.parquet \
 data.train_batch_size=1024 \
-data.val_batch_size=1312 \
 data.max_prompt_length=512 \
 data.max_response_length=512 \
 actor_rollout_ref.model.path=$HOME/models/deepseek-llm-7b-chat \
1 change: 0 additions & 1 deletion examples/ppo_trainer/run_gemma.sh
@@ -4,7 +4,6 @@ python3 -m verl.trainer.main_ppo \
 data.train_files=$HOME/data/gsm8k/train.parquet \
 data.val_files=$HOME/data/gsm8k/test.parquet \
 data.train_batch_size=512 \
-data.val_batch_size=1312 \
 data.max_prompt_length=1024 \
 data.max_response_length=512 \
 actor_rollout_ref.model.path=google/gemma-2-2b-it \
1 change: 0 additions & 1 deletion examples/ppo_trainer/run_qwen2-7b.sh
@@ -12,7 +12,6 @@ python3 -m verl.trainer.main_ppo \
 data.train_files="$train_files" \
 data.val_files="$test_files" \
 data.train_batch_size=1024 \
-data.val_batch_size=6312 \
 data.max_prompt_length=1024 \
 data.max_response_length=512 \
 actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
1 change: 0 additions & 1 deletion examples/ppo_trainer/run_qwen2-7b_math_gsm8k_megatron.sh
@@ -14,7 +14,6 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat
 data.train_files="$train_files" \
 data.val_files="$test_files" \
 data.train_batch_size=1024 \
-data.val_batch_size=6312 \
 data.max_prompt_length=1024 \
 data.max_response_length=512 \
 actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
1 change: 0 additions & 1 deletion examples/ppo_trainer/run_qwen2-7b_rm.sh
@@ -26,7 +26,6 @@ python3 -m verl.trainer.main_ppo \
 data.train_files="$train_files" \
 data.val_files="$test_files" \
 data.train_batch_size=1024 \
-data.val_batch_size=6312 \
 data.max_prompt_length=1024 \
 data.max_response_length=512 \
 data.return_raw_chat=True \
1 change: 0 additions & 1 deletion examples/ppo_trainer/run_qwen2-7b_rm_seq_balance.sh
@@ -12,7 +12,6 @@ python3 -m verl.trainer.main_ppo \
 data.train_files="$train_files" \
 data.val_files="$test_files" \
 data.train_batch_size=4096 \
-data.val_batch_size=1312 \
 data.max_prompt_length=4096 \
 data.max_response_length=4096 \
 data.return_raw_chat=True \
1 change: 0 additions & 1 deletion examples/ppo_trainer/run_qwen2-7b_seq_balance.sh
@@ -12,7 +12,6 @@ python3 -m verl.trainer.main_ppo \
 data.train_files="$train_files" \
 data.val_files="$test_files" \
 data.train_batch_size=4096 \
-data.val_batch_size=1312 \
 data.max_prompt_length=4096 \
 data.max_response_length=4096 \
 actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
1 change: 0 additions & 1 deletion examples/ppo_trainer/run_qwen2.5-32b.sh
@@ -12,7 +12,6 @@ python3 -m verl.trainer.main_ppo \
 data.train_files="$train_files" \
 data.val_files="$test_files" \
 data.train_batch_size=1024 \
-data.val_batch_size=6304 \
 data.max_prompt_length=1024 \
 data.max_response_length=1024 \
 actor_rollout_ref.model.path=Qwen/Qwen2.5-32B-Instruct \
14 changes: 7 additions & 7 deletions examples/ppo_trainer/verl_getting_started.ipynb
@@ -314,16 +314,16 @@
 "source": [
 "import torch\n",
 "try:\n",
-" assert torch.cuda.is_available() is True\n",
-" torch.ones(1, dtype=torch.bfloat16).cuda()\n",
+"    assert torch.cuda.is_available() is True\n",
+"    torch.ones(1, dtype=torch.bfloat16).cuda()\n",
 "except AssertionError:\n",
-" print(\"Please switch to an env with GPUs supporting bfloat16 (L4 RTX 5000, A5000, A100, H100, A10, etc)\")\n",
+"    print(\"Please switch to an env with GPUs supporting bfloat16 (L4 RTX 5000, A5000, A100, H100, A10, etc)\")\n",
 "\n",
 "try:\n",
-" import verl\n",
+"    import verl\n",
 "except Exception as e:\n",
-" print(\"Please install verl via pip and restart the kernel\")\n",
-" raise e\n",
+"    print(\"Please install verl via pip and restart the kernel\")\n",
+"    raise e\n",
 "\n",
 "import flash_attn"
 ]
@@ -561,6 +561,7 @@
 "source": [
 "import inspect\n",
 "from verl.utils.reward_score.gsm8k import compute_score as gsm8k_reward\n",
+"\n",
 "print(inspect.getsource(gsm8k_reward))"
 ]
 },
@@ -1103,7 +1104,6 @@
 "    data.train_files=$HOME/data/gsm8k/train.parquet \\\n",
 "    data.val_files=$HOME/data/gsm8k/test.parquet \\\n",
 "    data.train_batch_size=256 \\\n",
-"    data.val_batch_size=1312 \\\n",
 "    data.max_prompt_length=512 \\\n",
 "    data.max_response_length=256 \\\n",
 "    actor_rollout_ref.model.path=$HOME/models/Qwen2.5-0.5B-Instruct \\\n",
1 change: 0 additions & 1 deletion examples/remax_trainer/run_qwen2.5-3b_seq_balance.sh
@@ -10,7 +10,6 @@ python3 -m verl.trainer.main_ppo \
 data.train_files=$HOME/data/gsm8k/train.parquet \
 data.val_files=$HOME/data/gsm8k/train.parquet \
 data.train_batch_size=512 \
-data.val_batch_size=1312 \
 data.max_prompt_length=512 \
 data.max_response_length=1024 \
 actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \
1 change: 0 additions & 1 deletion examples/remax_trainer/run_qwen2.5-7b_seq_balance.sh
@@ -10,7 +10,6 @@ python3 -m verl.trainer.main_ppo \
 data.train_files=$HOME/data/gsm8k/train.parquet \
 data.val_files=$HOME/data/gsm8k/train.parquet \
 data.train_batch_size=1024 \
-data.val_batch_size=1312 \
 data.max_prompt_length=512 \
 data.max_response_length=1024 \
 actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \
1 change: 0 additions & 1 deletion examples/rloo_trainer/run_qwen2-7b.sh
@@ -7,7 +7,6 @@ python3 -m verl.trainer.main_ppo \
 data.train_files=$HOME/data/gsm8k/train.parquet \
 data.val_files=$HOME/data/gsm8k/test.parquet \
 data.train_batch_size=1024 \
-data.val_batch_size=1312 \
 data.max_prompt_length=512 \
 data.max_response_length=1024 \
 actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
1 change: 0 additions & 1 deletion examples/slurm/ray_on_slurm.slurm
@@ -75,7 +75,6 @@ PYTHONUNBUFFERED=1 srun --overlap --nodes=1 --ntasks=1 -w "$head_node" \
 data.train_files=$train_files \
 data.val_files=$val_files \
 data.train_batch_size=256 \
-data.val_batch_size=1312 \
 data.max_prompt_length=512 \
 data.max_response_length=256 \
 actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
2 changes: 1 addition & 1 deletion examples/split_placement/config/ppo_trainer_split.yaml
@@ -6,7 +6,7 @@ data:
 max_prompt_length: 512
 max_response_length: 512
 train_batch_size: 1024
-val_batch_size: 1312
+val_batch_size: null # DEPRECATED: Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves
 return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs
 return_raw_chat: False
 shuffle: True
1 change: 0 additions & 1 deletion examples/split_placement/run_deepseek7b_llm.sh
@@ -4,7 +4,6 @@ python3 main_ppo_split.py \
 data.train_files=$HOME/data/gsm8k/train.parquet \
 data.val_files=$HOME/data/gsm8k/test.parquet \
 data.train_batch_size=1024 \
-data.val_batch_size=1312 \
 data.max_prompt_length=512 \
 data.max_response_length=512 \
 actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \
1 change: 0 additions & 1 deletion tests/e2e/run_deepseek_megatron.sh
@@ -9,7 +9,6 @@ python3 -m verl.trainer.main_ppo --config-path=config \
 data.train_files=$HOME/data/gsm8k/train.parquet \
 data.val_files=$HOME/data/gsm8k/test.parquet \
 data.train_batch_size=1024 \
-data.val_batch_size=1312 \
 data.max_prompt_length=512 \
 data.max_response_length=512 \
 actor_rollout_ref.model.path=deepseek-ai/deepseek-coder-1.3b-instruct \
1 change: 0 additions & 1 deletion tests/e2e/run_qwen_gsm8k_function_rm.sh
@@ -6,7 +6,6 @@ python3 -m verl.trainer.main_ppo \
 data.train_files=$HOME/data/gsm8k/train.parquet \
 data.val_files=$HOME/data/gsm8k/test.parquet \
 data.train_batch_size=1024 \
-data.val_batch_size=1312 \
 data.max_prompt_length=512 \
 data.max_response_length=512 \
 actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \
1 change: 0 additions & 1 deletion tests/e2e/run_qwen_gsm8k_function_rm_grpo.sh
@@ -6,7 +6,6 @@ python3 -m verl.trainer.main_ppo \
 data.train_files=$HOME/data/gsm8k/train.parquet \
 data.val_files=$HOME/data/gsm8k/test.parquet \
 data.train_batch_size=1024 \
-data.val_batch_size=1312 \
 data.max_prompt_length=512 \
 data.max_response_length=512 \
 actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \
1 change: 0 additions & 1 deletion tests/e2e/run_qwen_gsm8k_function_rm_no_rmpad.sh
@@ -6,7 +6,6 @@ python3 -m verl.trainer.main_ppo \
 data.train_files=$HOME/data/gsm8k/train.parquet \
 data.val_files=$HOME/data/gsm8k/test.parquet \
 data.train_batch_size=1024 \
-data.val_batch_size=1312 \
 data.max_prompt_length=512 \
 data.max_response_length=512 \
 actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \
1 change: 0 additions & 1 deletion tests/e2e/run_qwen_gsm8k_function_rm_remax.sh
@@ -6,7 +6,6 @@ python3 -m verl.trainer.main_ppo \
 data.train_files=$HOME/data/gsm8k/train.parquet \
 data.val_files=$HOME/data/gsm8k/test.parquet \
 data.train_batch_size=1024 \
-data.val_batch_size=1312 \
 data.max_prompt_length=512 \
 data.max_response_length=512 \
 actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \
1 change: 0 additions & 1 deletion tests/e2e/run_qwen_gsm8k_model_rm.sh
@@ -6,7 +6,6 @@ python3 -m verl.trainer.main_ppo \
 data.train_files=$HOME/data/gsm8k/train.parquet \
 data.val_files=$HOME/data/gsm8k/test.parquet \
 data.train_batch_size=1024 \
-data.val_batch_size=1312 \
 data.max_prompt_length=512 \
 data.max_response_length=512 \
 data.return_raw_chat=True \
1 change: 0 additions & 1 deletion tests/e2e/run_qwen_gsm8k_model_rm_liger_kernel.sh
@@ -6,7 +6,6 @@ python3 -m verl.trainer.main_ppo \
 data.train_files=$HOME/data/gsm8k/train.parquet \
 data.val_files=$HOME/data/gsm8k/test.parquet \
 data.train_batch_size=1024 \
-data.val_batch_size=1312 \
 data.max_prompt_length=512 \
 data.max_response_length=512 \
 data.return_raw_chat=True \
1 change: 0 additions & 1 deletion tests/e2e/run_qwen_gsm8k_model_rm_no_rmpad.sh
@@ -6,7 +6,6 @@ python3 -m verl.trainer.main_ppo \
 data.train_files=$HOME/data/gsm8k/train.parquet \
 data.val_files=$HOME/data/gsm8k/test.parquet \
 data.train_batch_size=1024 \
-data.val_batch_size=1312 \
 data.max_prompt_length=512 \
 data.max_response_length=512 \
 data.return_raw_chat=True \
1 change: 0 additions & 1 deletion tests/e2e/run_qwen_gsm8k_model_rm_seq_balance.sh
@@ -6,7 +6,6 @@ python3 -m verl.trainer.main_ppo \
 data.train_files=$HOME/data/gsm8k/train.parquet \
 data.val_files=$HOME/data/gsm8k/test.parquet \
 data.train_batch_size=1024 \
-data.val_batch_size=1312 \
 data.max_prompt_length=512 \
 data.max_response_length=512 \
 data.return_raw_chat=True \
1 change: 0 additions & 1 deletion tests/e2e/run_qwen_gsm8k_model_rm_ulysses.sh
@@ -6,7 +6,6 @@ python3 -m verl.trainer.main_ppo \
 data.train_files=$HOME/data/gsm8k/train.parquet \
 data.val_files=$HOME/data/gsm8k/test.parquet \
 data.train_batch_size=1024 \
-data.val_batch_size=1312 \
 data.max_prompt_length=512 \
 data.max_response_length=512 \
 data.return_raw_chat=True \
1 change: 0 additions & 1 deletion tests/e2e/run_qwen_megatron.sh
@@ -9,7 +9,6 @@ python3 -m verl.trainer.main_ppo --config-path=config \
 data.train_files=$HOME/data/gsm8k/train.parquet \
 data.val_files=$HOME/data/gsm8k/test.parquet \
 data.train_batch_size=1024 \
-data.val_batch_size=1312 \
 data.max_prompt_length=512 \
 data.max_response_length=512 \
 actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \
1 change: 0 additions & 1 deletion tests/e2e/run_ray_trainer.sh
@@ -11,7 +11,6 @@ python3 tests/e2e/arithmetic_sequence/rl/main_trainer.py \
 data.train_files=tests/e2e/arithmetic_sequence/data/train.parquet \
 data.val_files=tests/e2e/arithmetic_sequence/data/test.parquet \
 data.train_batch_size=800 \
-data.val_batch_size=200 \
 data.max_prompt_length=16 \
 data.max_response_length=32 \
 data.return_raw_input_ids=True \
2 changes: 1 addition & 1 deletion verl/trainer/config/ppo_megatron_trainer.yaml
@@ -6,7 +6,7 @@ data:
 max_prompt_length: 512
 max_response_length: 512
 train_batch_size: 1024
-val_batch_size: 1312
+val_batch_size: null # DEPRECATED: Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves
 return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs
 return_raw_chat: False
 shuffle: True
2 changes: 1 addition & 1 deletion verl/trainer/config/ppo_trainer.yaml
@@ -6,7 +6,7 @@ data:
 max_prompt_length: 512
 max_response_length: 512
 train_batch_size: 1024
-val_batch_size: 1312
+val_batch_size: null # DEPRECATED: Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves
 return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs
 return_raw_chat: False
 shuffle: True
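As a quick sanity check, the shipped default can be loaded and inspected directly; this sketch assumes it is run from the verl repository root after the change:

    # Hedged sketch: confirm the new default. Assumes the repo root as CWD.
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("verl/trainer/config/ppo_trainer.yaml")
    print(cfg.data.val_batch_size)  # YAML null resolves to Python None
    assert cfg.data.val_batch_size is None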
18 changes: 13 additions & 5 deletions verl/trainer/ppo/ray_trainer.py
@@ -465,6 +465,11 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
             assert config.critic.model.use_remove_padding, \
                 "When using sequence parallelism for critic, you must enable `use_remove_padding`."

+        if config.data.get('val_batch_size', None) is not None:
+            print(
+                f"WARNING: val_batch_size is deprecated. Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves."
+            )
+
         print("[validate_config] All configuration checks passed successfully!")

     def _create_dataloader(self):
@@ -498,11 +503,14 @@ def _create_dataloader(self):
                                              filter_prompts=True,
                                              return_raw_chat=self.config.data.get('return_raw_chat', False),
                                              truncation='error')
-        self.val_dataloader = DataLoader(dataset=self.val_dataset,
-                                         batch_size=len(self.val_dataset),
-                                         shuffle=True,
-                                         drop_last=True,
-                                         collate_fn=collate_fn)
+        self.val_dataloader = DataLoader(
+            dataset=self.val_dataset,
+            # Validation datasets are sent to inference engines as a whole batch,
+            # which will schedule the memory themselves.
+            batch_size=len(self.val_dataset),
+            shuffle=True,
+            drop_last=False,
+            collate_fn=collate_fn)

         assert len(self.train_dataloader) >= 1
         assert len(self.val_dataloader) >= 1
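The heart of the change is the single-batch validation loader above. Below is a self-contained sketch of the same pattern; TensorDataset and the toy tensor are hypothetical stand-ins for verl's actual dataset and collate function, not code from the PR:

    # Whole-dataset validation: one DataLoader batch spanning the entire dataset.
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    val_dataset = TensorDataset(torch.arange(10).float().unsqueeze(1))  # toy stand-in

    val_dataloader = DataLoader(
        dataset=val_dataset,
        batch_size=len(val_dataset),  # the whole dataset forms a single batch
        shuffle=True,
        drop_last=False,              # keep every sample; nothing is truncated
    )

    assert len(val_dataloader) == 1           # validation sees exactly one batch
    (batch,) = next(iter(val_dataloader))
    assert batch.shape[0] == len(val_dataset)

With batch_size equal to the dataset length there is exactly one full batch either way, so switching to drop_last=False mainly guards against silently losing samples if the batch size and dataset length ever diverge.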