From f16026e9bfb549c03baaa555be47f37a5f2af5bb Mon Sep 17 00:00:00 2001 From: openhands Date: Mon, 3 Feb 2025 15:52:20 +0000 Subject: [PATCH 01/13] feat: Add multi-turn SFT dataset support - Add MultiTurnSFTDataset class for handling multi-turn conversations - Support different roles (system, user, assistant) with role-specific prefixes - Set loss mask to 1 for assistant responses only - Add comprehensive test suite for the new dataset class --- tests/soft/test_multiturn_sft_dataset.py | 86 ++++++++++++ verl/utils/dataset/multiturn_sft_dataset.py | 140 ++++++++++++++++++++ 2 files changed, 226 insertions(+) create mode 100644 tests/soft/test_multiturn_sft_dataset.py create mode 100644 verl/utils/dataset/multiturn_sft_dataset.py diff --git a/tests/soft/test_multiturn_sft_dataset.py b/tests/soft/test_multiturn_sft_dataset.py new file mode 100644 index 00000000..602470ac --- /dev/null +++ b/tests/soft/test_multiturn_sft_dataset.py @@ -0,0 +1,86 @@ +""" +Test the MultiTurnSFTDataset implementation +""" +import os +import pandas as pd +import torch +from transformers import AutoTokenizer +from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset + + +def test_multiturn_sft_dataset(): + # Create a temporary parquet file with test data + test_data = { + 'messages': [ + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "2+2 equals 4."}, + {"role": "user", "content": "And what is 4+4?"}, + {"role": "assistant", "content": "4+4 equals 8."} + ], + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Tell me a joke."}, + {"role": "assistant", "content": "Why did the chicken cross the road?"}, + {"role": "user", "content": "Why?"}, + {"role": "assistant", "content": "To get to the other side!"} + ] + ] + } + + # Create test directory if it doesn't exist + os.makedirs('test_data', exist_ok=True) + test_file = 'test_data/test.parquet' + + # Save test data to parquet + df = pd.DataFrame(test_data) + df.to_parquet(test_file) + + # Initialize tokenizer and dataset + tokenizer = AutoTokenizer.from_pretrained('gpt2') + dataset = MultiTurnSFTDataset( + parquet_files=test_file, + tokenizer=tokenizer, + max_length=512 + ) + + # Test dataset length + assert len(dataset) == 2, f"Expected dataset length 2, got {len(dataset)}" + + # Get first item + item = dataset[0] + + # Check that all required keys are present + required_keys = ['input_ids', 'attention_mask', 'position_ids', 'loss_mask'] + for key in required_keys: + assert key in item, f"Missing key {key} in dataset item" + assert isinstance(item[key], torch.Tensor), f"Expected torch.Tensor for {key}" + + # Verify loss mask shape matches input_ids + assert item['loss_mask'].shape == item['input_ids'].shape, \ + "Loss mask shape doesn't match input_ids shape" + + # Decode the tokens where loss_mask is 1 to verify they correspond to assistant messages + loss_mask = item['loss_mask'] + input_ids = item['input_ids'] + + # Get positions where loss_mask is 1 + assistant_positions = torch.where(loss_mask == 1)[0] + + # Verify that we have assistant positions with loss_mask=1 + assert len(assistant_positions) > 0, "No positions found with loss_mask=1" + + # Get all text from positions where loss_mask=1 + assistant_text = tokenizer.decode(input_ids[loss_mask == 1]) + print(f"Assistant text: {assistant_text}") + + # Verify it contains our expected assistant responses + assert any(x in assistant_text.lower() for x in 
['equals', 'get to the other side']), \ + f"Expected assistant response content, got: {assistant_text}" + + print("All tests passed!") + + +if __name__ == "__main__": + test_multiturn_sft_dataset() \ No newline at end of file diff --git a/verl/utils/dataset/multiturn_sft_dataset.py b/verl/utils/dataset/multiturn_sft_dataset.py new file mode 100644 index 00000000..00ccf1e3 --- /dev/null +++ b/verl/utils/dataset/multiturn_sft_dataset.py @@ -0,0 +1,140 @@ +""" +Multi-turn SFT dataset that supports training on conversation data with multiple turns +""" + +from typing import List, Union + +import pandas as pd +import torch +from torch.utils.data import Dataset +from transformers import PreTrainedTokenizer + +from verl.utils.fs import copy_local_path_from_hdfs +from verl.utils.model import compute_position_id_with_mask +from verl.utils import hf_tokenizer + + +class MultiTurnSFTDataset(Dataset): + """ + Dataset for multi-turn conversations where each assistant response should be trained + """ + + def __init__(self, + parquet_files: Union[str, List[str]], + tokenizer, + messages_key='messages', # Key for the messages list in the parquet file + max_length=1024, + truncation='error'): + assert truncation in ['error', 'left', 'right'] + self.truncation = truncation + + if not isinstance(parquet_files, List): + parquet_files = [parquet_files] + + self.parquet_files = parquet_files + if isinstance(tokenizer, str): + tokenizer = hf_tokenizer(tokenizer) + self.tokenizer: PreTrainedTokenizer = tokenizer + self.messages_key = messages_key + self.max_length = max_length + + self._download() + self._read_files_and_process() + + def _download(self): + for i, parquet_file in enumerate(self.parquet_files): + self.parquet_files[i] = copy_local_path_from_hdfs(parquet_file, verbose=True) + + def _read_files_and_process(self): + def series_to_item(ls): + import pandas, numpy + while isinstance(ls, (pandas.core.series.Series, numpy.ndarray)) and len(ls) == 1: + ls = ls[0] + return ls + + dataframes = [] + for parquet_file in self.parquet_files: + dataframe = pd.read_parquet(parquet_file) + dataframes.append(dataframe) + self.dataframe = pd.concat(dataframes) + + # Extract messages list from dataframe + self.messages = self.dataframe[self.messages_key].apply(series_to_item).tolist() + + def __len__(self): + return len(self.messages) + + def __getitem__(self, item): + tokenizer = self.tokenizer + messages = self.messages[item] + + # Process each message and concatenate with special tokens + all_tokens = [] + loss_mask_parts = [] + + for msg in messages: + # Add role prefix + if msg['role'] == 'system': + prefix = "<|system|>\n" + elif msg['role'] == 'user': + prefix = "<|user|>\n" + elif msg['role'] == 'assistant': + prefix = "<|assistant|>\n" + else: + raise ValueError(f"Unknown role: {msg['role']}") + + # Tokenize the message + msg_str = prefix + msg['content'] + "\n" + msg_tokens = tokenizer(msg_str, return_tensors='pt', add_special_tokens=False) + msg_ids = msg_tokens['input_ids'][0] + + # Create loss mask for this message (1 for assistant responses, 0 for others) + msg_mask = torch.zeros_like(msg_ids, dtype=torch.long) + if msg['role'] == 'assistant': + msg_mask = torch.ones_like(msg_ids, dtype=torch.long) + + all_tokens.append(msg_ids) + loss_mask_parts.append(msg_mask) + + # Concatenate all tokens and masks + input_ids = torch.cat(all_tokens) + loss_mask = torch.cat(loss_mask_parts) + attention_mask = torch.ones_like(input_ids) + + # Handle sequence length + sequence_length = input_ids.shape[0] + if 
sequence_length < self.max_length: + # Pad sequences + pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0 + padded_input_ids = torch.ones(size=(self.max_length - sequence_length,), + dtype=input_ids.dtype) * pad_token_id + padded_attention_mask = torch.zeros(size=(self.max_length - sequence_length,), + dtype=attention_mask.dtype) + padded_loss_mask = torch.zeros(size=(self.max_length - sequence_length,), + dtype=loss_mask.dtype) + + input_ids = torch.cat((input_ids, padded_input_ids)) + attention_mask = torch.cat((attention_mask, padded_attention_mask)) + loss_mask = torch.cat((loss_mask, padded_loss_mask)) + elif sequence_length > self.max_length: + if self.truncation == 'left': + input_ids = input_ids[-self.max_length:] + attention_mask = attention_mask[-self.max_length:] + loss_mask = loss_mask[-self.max_length:] + elif self.truncation == 'right': + input_ids = input_ids[:self.max_length] + attention_mask = attention_mask[:self.max_length] + loss_mask = loss_mask[:self.max_length] + elif self.truncation == 'error': + raise ValueError(f'{sequence_length=} is larger than {self.max_length=}') + else: + raise ValueError(f'Unknown truncation method {self.truncation}') + + position_ids = compute_position_id_with_mask(attention_mask) + + return { + 'input_ids': input_ids, + 'attention_mask': attention_mask, + 'position_ids': position_ids, + 'loss_mask': loss_mask + } \ No newline at end of file From 55cc8df83f4ba5700d1c28436bfa956741c1e5f6 Mon Sep 17 00:00:00 2001 From: openhands Date: Mon, 3 Feb 2025 17:57:18 +0000 Subject: [PATCH 02/13] fix: Use proper chat template for multi-turn dataset - Replace custom chat formatting with HuggingFace chat template - Use Qwen tokenizer for testing - Fix tensor indexing and loss mask generation - Update test to verify proper tokenization --- tests/soft/test_multiturn_sft_dataset.py | 2 +- verl/utils/dataset/multiturn_sft_dataset.py | 44 ++++++++------------- 2 files changed, 18 insertions(+), 28 deletions(-) diff --git a/tests/soft/test_multiturn_sft_dataset.py b/tests/soft/test_multiturn_sft_dataset.py index 602470ac..77a72c5f 100644 --- a/tests/soft/test_multiturn_sft_dataset.py +++ b/tests/soft/test_multiturn_sft_dataset.py @@ -38,7 +38,7 @@ def test_multiturn_sft_dataset(): df.to_parquet(test_file) # Initialize tokenizer and dataset - tokenizer = AutoTokenizer.from_pretrained('gpt2') + tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-Coder-7B-Instruct') dataset = MultiTurnSFTDataset( parquet_files=test_file, tokenizer=tokenizer, diff --git a/verl/utils/dataset/multiturn_sft_dataset.py b/verl/utils/dataset/multiturn_sft_dataset.py index 00ccf1e3..7ab9514c 100644 --- a/verl/utils/dataset/multiturn_sft_dataset.py +++ b/verl/utils/dataset/multiturn_sft_dataset.py @@ -68,38 +68,28 @@ def __getitem__(self, item): tokenizer = self.tokenizer messages = self.messages[item] - # Process each message and concatenate with special tokens - all_tokens = [] - loss_mask_parts = [] + # Use the tokenizer's chat template to format and tokenize the conversation + tokens = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors='pt', add_generation_prompt=False) + input_ids = tokens[0] # The output is already a tensor + attention_mask = torch.ones_like(input_ids) + + # Create loss mask by identifying assistant responses + loss_mask = torch.zeros_like(input_ids, dtype=torch.long) + # For each assistant message, find its position in the tokenized text + current_tokens = [] for msg in messages: - # Add role prefix 
- if msg['role'] == 'system': - prefix = "<|system|>\n" - elif msg['role'] == 'user': - prefix = "<|user|>\n" - elif msg['role'] == 'assistant': - prefix = "<|assistant|>\n" - else: - raise ValueError(f"Unknown role: {msg['role']}") + # Tokenize this message + msg_tokens = tokenizer.apply_chat_template([msg], tokenize=True, return_tensors='pt', add_generation_prompt=False) + msg_ids = msg_tokens[0] - # Tokenize the message - msg_str = prefix + msg['content'] + "\n" - msg_tokens = tokenizer(msg_str, return_tensors='pt', add_special_tokens=False) - msg_ids = msg_tokens['input_ids'][0] - - # Create loss mask for this message (1 for assistant responses, 0 for others) - msg_mask = torch.zeros_like(msg_ids, dtype=torch.long) + # If this is an assistant message, mark its tokens in the loss mask if msg['role'] == 'assistant': - msg_mask = torch.ones_like(msg_ids, dtype=torch.long) + start_idx = len(torch.cat(current_tokens)) if current_tokens else 0 + end_idx = start_idx + len(msg_ids) + loss_mask[start_idx:end_idx] = 1 - all_tokens.append(msg_ids) - loss_mask_parts.append(msg_mask) - - # Concatenate all tokens and masks - input_ids = torch.cat(all_tokens) - loss_mask = torch.cat(loss_mask_parts) - attention_mask = torch.ones_like(input_ids) + current_tokens.append(msg_ids) # Handle sequence length sequence_length = input_ids.shape[0] From 62e11a3b52a730139134a6d416f23c1766ed2d39 Mon Sep 17 00:00:00 2001 From: openhands Date: Mon, 3 Feb 2025 18:19:19 +0000 Subject: [PATCH 03/13] fix: Use proper chat template and improve tests - Use HuggingFace chat template instead of custom formatting - Add comprehensive tests for loss mask behavior - Verify both assistant and non-assistant content - Add debug output for test failures --- tests/soft/test_multiturn_sft_dataset.py | 135 ++++++++++++++++---- verl/utils/dataset/multiturn_sft_dataset.py | 38 +++--- 2 files changed, 131 insertions(+), 42 deletions(-) diff --git a/tests/soft/test_multiturn_sft_dataset.py b/tests/soft/test_multiturn_sft_dataset.py index 77a72c5f..47602313 100644 --- a/tests/soft/test_multiturn_sft_dataset.py +++ b/tests/soft/test_multiturn_sft_dataset.py @@ -9,6 +9,7 @@ def test_multiturn_sft_dataset(): + print("Starting test...") # Create a temporary parquet file with test data test_data = { 'messages': [ @@ -45,42 +46,124 @@ def test_multiturn_sft_dataset(): max_length=512 ) - # Test dataset length + # Test 1: Dataset Length assert len(dataset) == 2, f"Expected dataset length 2, got {len(dataset)}" - # Get first item - item = dataset[0] + # Get items for testing + item0 = dataset[0] # Math conversation + item1 = dataset[1] # Joke conversation - # Check that all required keys are present + # Test 2: Required Keys and Types required_keys = ['input_ids', 'attention_mask', 'position_ids', 'loss_mask'] for key in required_keys: - assert key in item, f"Missing key {key} in dataset item" - assert isinstance(item[key], torch.Tensor), f"Expected torch.Tensor for {key}" + assert key in item0, f"Missing key {key} in dataset item" + assert isinstance(item0[key], torch.Tensor), f"Expected torch.Tensor for {key}" + assert item0[key].dtype == torch.long, f"Expected torch.long for {key}, got {item0[key].dtype}" - # Verify loss mask shape matches input_ids - assert item['loss_mask'].shape == item['input_ids'].shape, \ + # Test 3: Shape Consistency + assert item0['loss_mask'].shape == item0['input_ids'].shape, \ "Loss mask shape doesn't match input_ids shape" + assert item0['attention_mask'].shape == item0['input_ids'].shape, \ + "Attention mask shape 
doesn't match input_ids shape" + assert item0['position_ids'].shape == item0['input_ids'].shape, \ + "Position IDs shape doesn't match input_ids shape" - # Decode the tokens where loss_mask is 1 to verify they correspond to assistant messages - loss_mask = item['loss_mask'] - input_ids = item['input_ids'] + # Test 4: Loss Mask Pattern - Math Conversation + loss_mask0 = item0['loss_mask'] + input_ids0 = item0['input_ids'] - # Get positions where loss_mask is 1 - assistant_positions = torch.where(loss_mask == 1)[0] + # Find assistant response positions + assistant_positions0 = torch.where(loss_mask0 == 1)[0] + assert len(assistant_positions0) > 0, "No assistant positions found in loss mask" - # Verify that we have assistant positions with loss_mask=1 - assert len(assistant_positions) > 0, "No positions found with loss_mask=1" + # Decode and verify assistant responses + assistant_text0 = tokenizer.decode(input_ids0[loss_mask0 == 1]) + print(f"Math conversation assistant text: {assistant_text0}") + assert "2+2 equals 4" in assistant_text0, "First assistant response not found" + assert "4+4 equals 8" in assistant_text0, "Second assistant response not found" - # Get all text from positions where loss_mask=1 - assistant_text = tokenizer.decode(input_ids[loss_mask == 1]) - print(f"Assistant text: {assistant_text}") + # Test 5: Loss Mask Pattern - Joke Conversation + loss_mask1 = item1['loss_mask'] + input_ids1 = item1['input_ids'] - # Verify it contains our expected assistant responses - assert any(x in assistant_text.lower() for x in ['equals', 'get to the other side']), \ - f"Expected assistant response content, got: {assistant_text}" + # Find assistant response positions + assistant_positions1 = torch.where(loss_mask1 == 1)[0] + assert len(assistant_positions1) > 0, "No assistant positions found in loss mask" - print("All tests passed!") - - -if __name__ == "__main__": - test_multiturn_sft_dataset() \ No newline at end of file + # Decode and verify assistant responses + assistant_text1 = tokenizer.decode(input_ids1[loss_mask1 == 1]) + print(f"Joke conversation assistant text: {assistant_text1}") + assert "chicken cross the road" in assistant_text1, "First assistant response not found" + assert "other side" in assistant_text1, "Second assistant response not found" + + # Test 6: Attention Mask Pattern + attention_mask0 = item0['attention_mask'] + sequence_length = torch.sum(attention_mask0) + assert sequence_length > 0, "No tokens marked as attended in attention mask" + assert torch.all(attention_mask0[:sequence_length] == 1), "Incorrect attention mask pattern" + if sequence_length < len(attention_mask0): + assert torch.all(attention_mask0[sequence_length:] == 0), "Padding not properly masked" + + # Test 7: Position IDs Pattern + position_ids0 = item0['position_ids'] + assert torch.equal(position_ids0[:sequence_length], torch.arange(sequence_length)), \ + "Position IDs not sequential for non-padded tokens" + if sequence_length < len(position_ids0): + assert torch.all(position_ids0[sequence_length:] == 0), "Padding position IDs not zero" + + # Test 8: Verify loss mask for assistant responses + # Get the full conversation text + full_text = tokenizer.decode(input_ids0) + print(f"\nFull conversation text:\n{full_text}") + + # Get the assistant responses + assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 1]) + print(f"\nAssistant responses (from loss mask):\n{assistant_text}") + + # Verify that loss mask is set for all assistant responses + for msg in test_data['messages'][0]: # First 
conversation + if msg['role'] == 'assistant': + # The content should appear in the masked text + assert msg['content'] in assistant_text, \ + f"Assistant message '{msg['content']}' not found in masked text" + + # The content should NOT appear in the non-masked text + non_assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 0]) + assert msg['content'] not in non_assistant_text, \ + f"Assistant message '{msg['content']}' found in non-assistant text" + + # Test 9: Verify non-assistant parts have loss_mask=0 + # Get non-assistant text + non_assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 0]) + print(f"\nNon-assistant text (from loss mask):\n{non_assistant_text}") + + # Verify that system and user messages are in the non-assistant text + for msg in test_data['messages'][0]: # First conversation + if msg['role'] in ['system', 'user']: + assert msg['content'] in non_assistant_text, \ + f"{msg['role'].title()} message '{msg['content']}' not found in non-assistant text" + + # And verify they're NOT in the assistant text + assert msg['content'] not in assistant_text, \ + f"{msg['role'].title()} message '{msg['content']}' found in assistant text" + + # Test 10: Verify padding behavior + small_dataset = MultiTurnSFTDataset( + parquet_files=test_file, + tokenizer=tokenizer, + max_length=1024 # Larger than needed to test padding + ) + padded_item = small_dataset[0] + + # Get actual sequence length (before padding) + actual_length = torch.sum(padded_item['attention_mask']) + + # Verify padding tokens + assert torch.all(padded_item['input_ids'][actual_length:] == tokenizer.pad_token_id), \ + "Padding tokens not set correctly" + assert torch.all(padded_item['attention_mask'][actual_length:] == 0), \ + "Attention mask not set correctly for padding" + assert torch.all(padded_item['loss_mask'][actual_length:] == 0), \ + "Loss mask not set correctly for padding" + + print("All tests passed!") \ No newline at end of file diff --git a/verl/utils/dataset/multiturn_sft_dataset.py b/verl/utils/dataset/multiturn_sft_dataset.py index 7ab9514c..3a7a5b34 100644 --- a/verl/utils/dataset/multiturn_sft_dataset.py +++ b/verl/utils/dataset/multiturn_sft_dataset.py @@ -68,28 +68,31 @@ def __getitem__(self, item): tokenizer = self.tokenizer messages = self.messages[item] - # Use the tokenizer's chat template to format and tokenize the conversation - tokens = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors='pt', add_generation_prompt=False) - input_ids = tokens[0] # The output is already a tensor + # First, get the full conversation tokens + full_tokens = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors='pt', add_generation_prompt=False) + input_ids = full_tokens[0] # The output is already a tensor attention_mask = torch.ones_like(input_ids) # Create loss mask by identifying assistant responses loss_mask = torch.zeros_like(input_ids, dtype=torch.long) - # For each assistant message, find its position in the tokenized text - current_tokens = [] - for msg in messages: - # Tokenize this message - msg_tokens = tokenizer.apply_chat_template([msg], tokenize=True, return_tensors='pt', add_generation_prompt=False) - msg_ids = msg_tokens[0] + # Process each message to find assistant responses + current_length = 0 + for i, msg in enumerate(messages): + # Get tokens for messages up to this point to find the start position + prefix_messages = messages[:i+1] + prefix_tokens = tokenizer.apply_chat_template(prefix_messages, tokenize=True, return_tensors='pt', 
add_generation_prompt=False) - # If this is an assistant message, mark its tokens in the loss mask - if msg['role'] == 'assistant': - start_idx = len(torch.cat(current_tokens)) if current_tokens else 0 - end_idx = start_idx + len(msg_ids) - loss_mask[start_idx:end_idx] = 1 + # Get tokens for messages up to previous point + prev_tokens = tokenizer.apply_chat_template(messages[:i], tokenize=True, return_tensors='pt', add_generation_prompt=False) if i > 0 else None + + # Calculate start and end positions + start_pos = prev_tokens[0].shape[0] if prev_tokens is not None else 0 + end_pos = prefix_tokens[0].shape[0] - current_tokens.append(msg_ids) + # If this is an assistant message, set loss mask + if msg['role'] == 'assistant': + loss_mask[start_pos:end_pos] = 1 # Handle sequence length sequence_length = input_ids.shape[0] @@ -120,7 +123,10 @@ def __getitem__(self, item): else: raise ValueError(f'Unknown truncation method {self.truncation}') - position_ids = compute_position_id_with_mask(attention_mask) + # Create position IDs + position_ids = torch.arange(len(input_ids), dtype=torch.long) + # Zero out position IDs for padding + position_ids = position_ids * attention_mask return { 'input_ids': input_ids, From 0dbd4dd6cd6bbb0b47d29d9f05e106d59fea4ce2 Mon Sep 17 00:00:00 2001 From: openhands Date: Mon, 3 Feb 2025 18:24:32 +0000 Subject: [PATCH 04/13] ci: Add unit tests workflow - Add separate workflow for unit tests - Run tests in tests/soft directory - Generate and upload coverage reports - Use same container as e2e tests --- .github/workflows/unit_tests.yml | 49 ++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 .github/workflows/unit_tests.yml diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml new file mode 100644 index 00000000..48d30c85 --- /dev/null +++ b/.github/workflows/unit_tests.yml @@ -0,0 +1,49 @@ +name: Unit Tests + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + paths: + - "**/*.py" + - .github/workflows/unit_tests.yml + - "tests/soft/*.py" + pull_request: + branches: + - main + paths: + - "**/*.py" + - .github/workflows/unit_tests.yml + - "tests/soft/*.py" + +jobs: + unit_tests: + runs-on: ubuntu-latest + container: + image: verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3 + options: --gpus all --shm-size=10g + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Install dependencies + run: | + pip3 install -e .[test] + pip3 install pytest pytest-cov + + - name: Run unit tests + run: | + pytest tests/soft/ -v --cov=verl --cov-report=xml + + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v3 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + with: + files: ./coverage.xml + flags: unittests + name: codecov-umbrella + fail_ci_if_error: true \ No newline at end of file From 60e78628b044ffa1e91fd898c1fec2b811ab8da8 Mon Sep 17 00:00:00 2001 From: openhands Date: Mon, 3 Feb 2025 18:27:59 +0000 Subject: [PATCH 05/13] refactor: Move unit tests to tests/sft/unit - Move tests from tests/soft to tests/sft/unit for consistency - Update CI workflow paths - Keep all SFT-related tests under tests/sft --- .github/workflows/unit_tests.yml | 6 +- tests/sft/unit/test_multiturn_sft_dataset.py | 169 +++++++++++++++++++ 2 files changed, 172 insertions(+), 3 deletions(-) create mode 100644 tests/sft/unit/test_multiturn_sft_dataset.py diff --git a/.github/workflows/unit_tests.yml 
b/.github/workflows/unit_tests.yml index 48d30c85..d7585393 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -9,14 +9,14 @@ on: paths: - "**/*.py" - .github/workflows/unit_tests.yml - - "tests/soft/*.py" + - "tests/sft/unit/*.py" pull_request: branches: - main paths: - "**/*.py" - .github/workflows/unit_tests.yml - - "tests/soft/*.py" + - "tests/sft/unit/*.py" jobs: unit_tests: @@ -36,7 +36,7 @@ jobs: - name: Run unit tests run: | - pytest tests/soft/ -v --cov=verl --cov-report=xml + pytest tests/sft/unit/ -v --cov=verl --cov-report=xml - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v3 diff --git a/tests/sft/unit/test_multiturn_sft_dataset.py b/tests/sft/unit/test_multiturn_sft_dataset.py new file mode 100644 index 00000000..47602313 --- /dev/null +++ b/tests/sft/unit/test_multiturn_sft_dataset.py @@ -0,0 +1,169 @@ +""" +Test the MultiTurnSFTDataset implementation +""" +import os +import pandas as pd +import torch +from transformers import AutoTokenizer +from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset + + +def test_multiturn_sft_dataset(): + print("Starting test...") + # Create a temporary parquet file with test data + test_data = { + 'messages': [ + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "2+2 equals 4."}, + {"role": "user", "content": "And what is 4+4?"}, + {"role": "assistant", "content": "4+4 equals 8."} + ], + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Tell me a joke."}, + {"role": "assistant", "content": "Why did the chicken cross the road?"}, + {"role": "user", "content": "Why?"}, + {"role": "assistant", "content": "To get to the other side!"} + ] + ] + } + + # Create test directory if it doesn't exist + os.makedirs('test_data', exist_ok=True) + test_file = 'test_data/test.parquet' + + # Save test data to parquet + df = pd.DataFrame(test_data) + df.to_parquet(test_file) + + # Initialize tokenizer and dataset + tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-Coder-7B-Instruct') + dataset = MultiTurnSFTDataset( + parquet_files=test_file, + tokenizer=tokenizer, + max_length=512 + ) + + # Test 1: Dataset Length + assert len(dataset) == 2, f"Expected dataset length 2, got {len(dataset)}" + + # Get items for testing + item0 = dataset[0] # Math conversation + item1 = dataset[1] # Joke conversation + + # Test 2: Required Keys and Types + required_keys = ['input_ids', 'attention_mask', 'position_ids', 'loss_mask'] + for key in required_keys: + assert key in item0, f"Missing key {key} in dataset item" + assert isinstance(item0[key], torch.Tensor), f"Expected torch.Tensor for {key}" + assert item0[key].dtype == torch.long, f"Expected torch.long for {key}, got {item0[key].dtype}" + + # Test 3: Shape Consistency + assert item0['loss_mask'].shape == item0['input_ids'].shape, \ + "Loss mask shape doesn't match input_ids shape" + assert item0['attention_mask'].shape == item0['input_ids'].shape, \ + "Attention mask shape doesn't match input_ids shape" + assert item0['position_ids'].shape == item0['input_ids'].shape, \ + "Position IDs shape doesn't match input_ids shape" + + # Test 4: Loss Mask Pattern - Math Conversation + loss_mask0 = item0['loss_mask'] + input_ids0 = item0['input_ids'] + + # Find assistant response positions + assistant_positions0 = torch.where(loss_mask0 == 1)[0] + assert len(assistant_positions0) > 0, "No assistant 
positions found in loss mask" + + # Decode and verify assistant responses + assistant_text0 = tokenizer.decode(input_ids0[loss_mask0 == 1]) + print(f"Math conversation assistant text: {assistant_text0}") + assert "2+2 equals 4" in assistant_text0, "First assistant response not found" + assert "4+4 equals 8" in assistant_text0, "Second assistant response not found" + + # Test 5: Loss Mask Pattern - Joke Conversation + loss_mask1 = item1['loss_mask'] + input_ids1 = item1['input_ids'] + + # Find assistant response positions + assistant_positions1 = torch.where(loss_mask1 == 1)[0] + assert len(assistant_positions1) > 0, "No assistant positions found in loss mask" + + # Decode and verify assistant responses + assistant_text1 = tokenizer.decode(input_ids1[loss_mask1 == 1]) + print(f"Joke conversation assistant text: {assistant_text1}") + assert "chicken cross the road" in assistant_text1, "First assistant response not found" + assert "other side" in assistant_text1, "Second assistant response not found" + + # Test 6: Attention Mask Pattern + attention_mask0 = item0['attention_mask'] + sequence_length = torch.sum(attention_mask0) + assert sequence_length > 0, "No tokens marked as attended in attention mask" + assert torch.all(attention_mask0[:sequence_length] == 1), "Incorrect attention mask pattern" + if sequence_length < len(attention_mask0): + assert torch.all(attention_mask0[sequence_length:] == 0), "Padding not properly masked" + + # Test 7: Position IDs Pattern + position_ids0 = item0['position_ids'] + assert torch.equal(position_ids0[:sequence_length], torch.arange(sequence_length)), \ + "Position IDs not sequential for non-padded tokens" + if sequence_length < len(position_ids0): + assert torch.all(position_ids0[sequence_length:] == 0), "Padding position IDs not zero" + + # Test 8: Verify loss mask for assistant responses + # Get the full conversation text + full_text = tokenizer.decode(input_ids0) + print(f"\nFull conversation text:\n{full_text}") + + # Get the assistant responses + assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 1]) + print(f"\nAssistant responses (from loss mask):\n{assistant_text}") + + # Verify that loss mask is set for all assistant responses + for msg in test_data['messages'][0]: # First conversation + if msg['role'] == 'assistant': + # The content should appear in the masked text + assert msg['content'] in assistant_text, \ + f"Assistant message '{msg['content']}' not found in masked text" + + # The content should NOT appear in the non-masked text + non_assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 0]) + assert msg['content'] not in non_assistant_text, \ + f"Assistant message '{msg['content']}' found in non-assistant text" + + # Test 9: Verify non-assistant parts have loss_mask=0 + # Get non-assistant text + non_assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 0]) + print(f"\nNon-assistant text (from loss mask):\n{non_assistant_text}") + + # Verify that system and user messages are in the non-assistant text + for msg in test_data['messages'][0]: # First conversation + if msg['role'] in ['system', 'user']: + assert msg['content'] in non_assistant_text, \ + f"{msg['role'].title()} message '{msg['content']}' not found in non-assistant text" + + # And verify they're NOT in the assistant text + assert msg['content'] not in assistant_text, \ + f"{msg['role'].title()} message '{msg['content']}' found in assistant text" + + # Test 10: Verify padding behavior + small_dataset = MultiTurnSFTDataset( + parquet_files=test_file, + 
tokenizer=tokenizer, + max_length=1024 # Larger than needed to test padding + ) + padded_item = small_dataset[0] + + # Get actual sequence length (before padding) + actual_length = torch.sum(padded_item['attention_mask']) + + # Verify padding tokens + assert torch.all(padded_item['input_ids'][actual_length:] == tokenizer.pad_token_id), \ + "Padding tokens not set correctly" + assert torch.all(padded_item['attention_mask'][actual_length:] == 0), \ + "Attention mask not set correctly for padding" + assert torch.all(padded_item['loss_mask'][actual_length:] == 0), \ + "Loss mask not set correctly for padding" + + print("All tests passed!") \ No newline at end of file From 3c3be7a286d1fca9c0c63c858b93c0b96be8fa27 Mon Sep 17 00:00:00 2001 From: openhands Date: Mon, 3 Feb 2025 19:18:15 +0000 Subject: [PATCH 06/13] feat: Add multi-turn training support - Update trainer to support both single-turn and multi-turn datasets - Add example script for multi-turn training - Add data preprocessing script for multi-turn conversations - Use proper chat template for multi-turn data --- examples/data_preprocess/multiturn.py | 74 +++++++++++++++++++++++ examples/sft/multiturn/run_qwen_05_sp2.sh | 30 +++++++++ verl/trainer/fsdp_sft_trainer.py | 57 ++++++++++++----- 3 files changed, 145 insertions(+), 16 deletions(-) create mode 100644 examples/data_preprocess/multiturn.py create mode 100755 examples/sft/multiturn/run_qwen_05_sp2.sh diff --git a/examples/data_preprocess/multiturn.py b/examples/data_preprocess/multiturn.py new file mode 100644 index 00000000..135ad514 --- /dev/null +++ b/examples/data_preprocess/multiturn.py @@ -0,0 +1,74 @@ +""" +Example script for preprocessing multi-turn conversation data into parquet format +""" + +import os +import argparse +import datasets +from verl.utils.hdfs_io import copy, makedirs + + +def process_conversation(example, idx, split): + """Convert a conversation into the expected format""" + messages = [] + + # Add system message if present + if example.get('system', ''): + messages.append({ + "role": "system", + "content": example['system'] + }) + + # Add conversation turns + for turn in example['conversation']: + messages.append({ + "role": turn['role'], + "content": turn['content'] + }) + + # Return the processed data + return { + "data_source": "multiturn_example", + "messages": messages, + "extra_info": { + 'split': split, + 'index': idx, + 'original': example + } + } + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--local_dir', default='~/data/multiturn') + parser.add_argument('--hdfs_dir', default=None) + + args = parser.parse_args() + + # Load your dataset here + # This is just an example - replace with your actual data loading + dataset = datasets.load_dataset('your_dataset_name') + + train_dataset = dataset['train'] + test_dataset = dataset['test'] + + # Process the datasets + train_dataset = train_dataset.map( + function=lambda x, i: process_conversation(x, i, 'train'), + with_indices=True + ) + test_dataset = test_dataset.map( + function=lambda x, i: process_conversation(x, i, 'test'), + with_indices=True + ) + + # Save to parquet files + local_dir = args.local_dir + hdfs_dir = args.hdfs_dir + + train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet')) + test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet')) + + if hdfs_dir is not None: + makedirs(hdfs_dir) + copy(src=local_dir, dst=hdfs_dir) \ No newline at end of file diff --git a/examples/sft/multiturn/run_qwen_05_sp2.sh 
b/examples/sft/multiturn/run_qwen_05_sp2.sh
new file mode 100755
index 00000000..48d2a2fb
--- /dev/null
+++ b/examples/sft/multiturn/run_qwen_05_sp2.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+set -x
+
+if [ "$#" -lt 2 ]; then
+    echo "Usage: run_qwen_05_sp2.sh <nproc_per_node> <save_path> [other_configs...]"
+    exit 1
+fi
+
+nproc_per_node=$1
+save_path=$2
+
+# Shift the arguments so $@ refers to the rest
+shift 2
+
+torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \
+    -m verl.trainer.fsdp_sft_trainer \
+    data.train_files=$HOME/data/multiturn/train.parquet \
+    data.val_files=$HOME/data/multiturn/test.parquet \
+    data.use_multiturn=true \
+    data.messages_key=messages \
+    data.micro_batch_size=4 \
+    model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \
+    trainer.default_local_dir=$save_path \
+    trainer.project_name=multiturn-sft \
+    trainer.experiment_name=multiturn-sft-qwen-2.5-0.5b-instruct-sp2 \
+    trainer.logger=['console'] \
+    trainer.total_training_steps=1 \
+    trainer.default_hdfs_dir=null $@ \
+    ulysses_sequence_parallel_size=2 \
+    use_remove_padding=true
\ No newline at end of file
diff --git a/verl/trainer/fsdp_sft_trainer.py b/verl/trainer/fsdp_sft_trainer.py
index b715c8cd..2f27b668 100644
--- a/verl/trainer/fsdp_sft_trainer.py
+++ b/verl/trainer/fsdp_sft_trainer.py
@@ -39,6 +39,7 @@
 from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager
 from verl.utils.dataset import SFTDataset
+from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset
 from verl.utils.fs import copy_local_path_from_hdfs
 from verl.utils.tracking import Tracking
 from verl.utils.ulysses import get_ulysses_sequence_parallel_world_size, set_ulysses_sequence_parallel_group
@@ -121,22 +122,46 @@ def _normalize_config_bsz(self):
     def _build_dataloader(self):
         config = self.config
         # build dataset
-        self.train_dataset = SFTDataset(parquet_files=config.data.train_files,
-                                        tokenizer=self.tokenizer,
-                                        prompt_key=config.data.prompt_key,
-                                        prompt_dict_keys=config.data.get('prompt_dict_keys', None),
-                                        response_key=config.data.response_key,
-                                        response_dict_keys=config.data.get('response_dict_keys', None),
-                                        max_length=config.data.max_length,
-                                        truncation=config.data.truncation)
-        self.val_dataset = SFTDataset(parquet_files=config.data.val_files,
-                                      tokenizer=self.tokenizer,
-                                      prompt_key=config.data.prompt_key,
-                                      prompt_dict_keys=config.data.get('prompt_dict_keys', None),
-                                      response_key=config.data.response_key,
-                                      response_dict_keys=config.data.get('response_dict_keys', None),
-                                      max_length=config.data.max_length,
-                                      truncation=config.data.truncation)
+        dataset_class = MultiTurnSFTDataset if config.data.get('use_multiturn', False) else SFTDataset
+
+        if dataset_class == MultiTurnSFTDataset:
+            # Multi-turn dataset uses messages_key instead of prompt/response keys
+            self.train_dataset = dataset_class(
+                parquet_files=config.data.train_files,
+                tokenizer=self.tokenizer,
+                messages_key=config.data.messages_key,
+                max_length=config.data.max_length,
+                truncation=config.data.truncation
+            )
+            self.val_dataset = dataset_class(
+                parquet_files=config.data.val_files,
+                tokenizer=self.tokenizer,
+                messages_key=config.data.messages_key,
+                max_length=config.data.max_length,
+                truncation=config.data.truncation
+            )
+        else:
+            # Single-turn dataset uses prompt/response keys
+            self.train_dataset = dataset_class(
+                parquet_files=config.data.train_files,
+                tokenizer=self.tokenizer,
+                prompt_key=config.data.prompt_key,
+                prompt_dict_keys=config.data.get('prompt_dict_keys', None),
+                response_key=config.data.response_key,
+
response_dict_keys=config.data.get('response_dict_keys', None), + max_length=config.data.max_length, + truncation=config.data.truncation + ) + self.val_dataset = dataset_class( + parquet_files=config.data.val_files, + tokenizer=self.tokenizer, + prompt_key=config.data.prompt_key, + prompt_dict_keys=config.data.get('prompt_dict_keys', None), + response_key=config.data.response_key, + response_dict_keys=config.data.get('response_dict_keys', None), + max_length=config.data.max_length, + truncation=config.data.truncation + ) # build dataloader # Use data parallel rank and size instead of global rank and world size From 46d08d25040fcdc00563768b5f65183fa12b0b99 Mon Sep 17 00:00:00 2001 From: openhands Date: Mon, 3 Feb 2025 19:21:49 +0000 Subject: [PATCH 07/13] chore: Remove old test file location --- tests/soft/test_multiturn_sft_dataset.py | 169 ----------------------- 1 file changed, 169 deletions(-) delete mode 100644 tests/soft/test_multiturn_sft_dataset.py diff --git a/tests/soft/test_multiturn_sft_dataset.py b/tests/soft/test_multiturn_sft_dataset.py deleted file mode 100644 index 47602313..00000000 --- a/tests/soft/test_multiturn_sft_dataset.py +++ /dev/null @@ -1,169 +0,0 @@ -""" -Test the MultiTurnSFTDataset implementation -""" -import os -import pandas as pd -import torch -from transformers import AutoTokenizer -from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset - - -def test_multiturn_sft_dataset(): - print("Starting test...") - # Create a temporary parquet file with test data - test_data = { - 'messages': [ - [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is 2+2?"}, - {"role": "assistant", "content": "2+2 equals 4."}, - {"role": "user", "content": "And what is 4+4?"}, - {"role": "assistant", "content": "4+4 equals 8."} - ], - [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Tell me a joke."}, - {"role": "assistant", "content": "Why did the chicken cross the road?"}, - {"role": "user", "content": "Why?"}, - {"role": "assistant", "content": "To get to the other side!"} - ] - ] - } - - # Create test directory if it doesn't exist - os.makedirs('test_data', exist_ok=True) - test_file = 'test_data/test.parquet' - - # Save test data to parquet - df = pd.DataFrame(test_data) - df.to_parquet(test_file) - - # Initialize tokenizer and dataset - tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-Coder-7B-Instruct') - dataset = MultiTurnSFTDataset( - parquet_files=test_file, - tokenizer=tokenizer, - max_length=512 - ) - - # Test 1: Dataset Length - assert len(dataset) == 2, f"Expected dataset length 2, got {len(dataset)}" - - # Get items for testing - item0 = dataset[0] # Math conversation - item1 = dataset[1] # Joke conversation - - # Test 2: Required Keys and Types - required_keys = ['input_ids', 'attention_mask', 'position_ids', 'loss_mask'] - for key in required_keys: - assert key in item0, f"Missing key {key} in dataset item" - assert isinstance(item0[key], torch.Tensor), f"Expected torch.Tensor for {key}" - assert item0[key].dtype == torch.long, f"Expected torch.long for {key}, got {item0[key].dtype}" - - # Test 3: Shape Consistency - assert item0['loss_mask'].shape == item0['input_ids'].shape, \ - "Loss mask shape doesn't match input_ids shape" - assert item0['attention_mask'].shape == item0['input_ids'].shape, \ - "Attention mask shape doesn't match input_ids shape" - assert item0['position_ids'].shape == item0['input_ids'].shape, \ - "Position IDs shape 
doesn't match input_ids shape" - - # Test 4: Loss Mask Pattern - Math Conversation - loss_mask0 = item0['loss_mask'] - input_ids0 = item0['input_ids'] - - # Find assistant response positions - assistant_positions0 = torch.where(loss_mask0 == 1)[0] - assert len(assistant_positions0) > 0, "No assistant positions found in loss mask" - - # Decode and verify assistant responses - assistant_text0 = tokenizer.decode(input_ids0[loss_mask0 == 1]) - print(f"Math conversation assistant text: {assistant_text0}") - assert "2+2 equals 4" in assistant_text0, "First assistant response not found" - assert "4+4 equals 8" in assistant_text0, "Second assistant response not found" - - # Test 5: Loss Mask Pattern - Joke Conversation - loss_mask1 = item1['loss_mask'] - input_ids1 = item1['input_ids'] - - # Find assistant response positions - assistant_positions1 = torch.where(loss_mask1 == 1)[0] - assert len(assistant_positions1) > 0, "No assistant positions found in loss mask" - - # Decode and verify assistant responses - assistant_text1 = tokenizer.decode(input_ids1[loss_mask1 == 1]) - print(f"Joke conversation assistant text: {assistant_text1}") - assert "chicken cross the road" in assistant_text1, "First assistant response not found" - assert "other side" in assistant_text1, "Second assistant response not found" - - # Test 6: Attention Mask Pattern - attention_mask0 = item0['attention_mask'] - sequence_length = torch.sum(attention_mask0) - assert sequence_length > 0, "No tokens marked as attended in attention mask" - assert torch.all(attention_mask0[:sequence_length] == 1), "Incorrect attention mask pattern" - if sequence_length < len(attention_mask0): - assert torch.all(attention_mask0[sequence_length:] == 0), "Padding not properly masked" - - # Test 7: Position IDs Pattern - position_ids0 = item0['position_ids'] - assert torch.equal(position_ids0[:sequence_length], torch.arange(sequence_length)), \ - "Position IDs not sequential for non-padded tokens" - if sequence_length < len(position_ids0): - assert torch.all(position_ids0[sequence_length:] == 0), "Padding position IDs not zero" - - # Test 8: Verify loss mask for assistant responses - # Get the full conversation text - full_text = tokenizer.decode(input_ids0) - print(f"\nFull conversation text:\n{full_text}") - - # Get the assistant responses - assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 1]) - print(f"\nAssistant responses (from loss mask):\n{assistant_text}") - - # Verify that loss mask is set for all assistant responses - for msg in test_data['messages'][0]: # First conversation - if msg['role'] == 'assistant': - # The content should appear in the masked text - assert msg['content'] in assistant_text, \ - f"Assistant message '{msg['content']}' not found in masked text" - - # The content should NOT appear in the non-masked text - non_assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 0]) - assert msg['content'] not in non_assistant_text, \ - f"Assistant message '{msg['content']}' found in non-assistant text" - - # Test 9: Verify non-assistant parts have loss_mask=0 - # Get non-assistant text - non_assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 0]) - print(f"\nNon-assistant text (from loss mask):\n{non_assistant_text}") - - # Verify that system and user messages are in the non-assistant text - for msg in test_data['messages'][0]: # First conversation - if msg['role'] in ['system', 'user']: - assert msg['content'] in non_assistant_text, \ - f"{msg['role'].title()} message '{msg['content']}' not found in non-assistant 
text" - - # And verify they're NOT in the assistant text - assert msg['content'] not in assistant_text, \ - f"{msg['role'].title()} message '{msg['content']}' found in assistant text" - - # Test 10: Verify padding behavior - small_dataset = MultiTurnSFTDataset( - parquet_files=test_file, - tokenizer=tokenizer, - max_length=1024 # Larger than needed to test padding - ) - padded_item = small_dataset[0] - - # Get actual sequence length (before padding) - actual_length = torch.sum(padded_item['attention_mask']) - - # Verify padding tokens - assert torch.all(padded_item['input_ids'][actual_length:] == tokenizer.pad_token_id), \ - "Padding tokens not set correctly" - assert torch.all(padded_item['attention_mask'][actual_length:] == 0), \ - "Attention mask not set correctly for padding" - assert torch.all(padded_item['loss_mask'][actual_length:] == 0), \ - "Loss mask not set correctly for padding" - - print("All tests passed!") \ No newline at end of file From 9e904277bef48402b33bc8d10bc63dbdd06f5909 Mon Sep 17 00:00:00 2001 From: openhands Date: Mon, 3 Feb 2025 19:24:15 +0000 Subject: [PATCH 08/13] feat: Add multi-turn config defaults - Add use_multiturn flag (default: false) - Add messages_key for multi-turn mode (default: messages) - Group single-turn and multi-turn settings --- verl/trainer/config/sft_trainer.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/verl/trainer/config/sft_trainer.yaml b/verl/trainer/config/sft_trainer.yaml index b0b10055..200db09c 100644 --- a/verl/trainer/config/sft_trainer.yaml +++ b/verl/trainer/config/sft_trainer.yaml @@ -4,8 +4,12 @@ data: micro_batch_size_per_gpu: 4 # this is also val batch size train_files: ~/data/gsm8k/train.parquet val_files: ~/data/gsm8k/test.parquet + # Single-turn settings prompt_key: question response_key: answer + # Multi-turn settings + use_multiturn: false # Set to true to use multi-turn dataset + messages_key: messages # Key for messages list in multi-turn mode max_length: 1024 truncation: error balance_dp_token: False From 8432ca1920b2fe97b32073bc73f958f2b41b3dcc Mon Sep 17 00:00:00 2001 From: openhands Date: Mon, 3 Feb 2025 19:59:31 +0000 Subject: [PATCH 09/13] feat: Update multi-turn examples - Add OpenHands SFT dataset preprocessing script - Add token length limit (32k) for conversations - Move multi-turn example to tests/sft - Add train/test split and statistics --- .github/workflows/unit_tests.yml | 49 ----------- examples/data_preprocess/multiturn.py | 117 ++++++++++++++++++-------- tests/sft/run_sft_multiturn.sh | 30 +++++++ 3 files changed, 111 insertions(+), 85 deletions(-) delete mode 100644 .github/workflows/unit_tests.yml create mode 100755 tests/sft/run_sft_multiturn.sh diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml deleted file mode 100644 index d7585393..00000000 --- a/.github/workflows/unit_tests.yml +++ /dev/null @@ -1,49 +0,0 @@ -name: Unit Tests - -on: - # Trigger the workflow on push or pull request, - # but only for the main branch - push: - branches: - - main - paths: - - "**/*.py" - - .github/workflows/unit_tests.yml - - "tests/sft/unit/*.py" - pull_request: - branches: - - main - paths: - - "**/*.py" - - .github/workflows/unit_tests.yml - - "tests/sft/unit/*.py" - -jobs: - unit_tests: - runs-on: ubuntu-latest - container: - image: verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3 - options: --gpus all --shm-size=10g - steps: - - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Install dependencies - run: | - pip3 install -e .[test] - pip3 
install pytest pytest-cov - - - name: Run unit tests - run: | - pytest tests/sft/unit/ -v --cov=verl --cov-report=xml - - - name: Upload coverage reports to Codecov - uses: codecov/codecov-action@v3 - env: - CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} - with: - files: ./coverage.xml - flags: unittests - name: codecov-umbrella - fail_ci_if_error: true \ No newline at end of file diff --git a/examples/data_preprocess/multiturn.py b/examples/data_preprocess/multiturn.py index 135ad514..9899bcf9 100644 --- a/examples/data_preprocess/multiturn.py +++ b/examples/data_preprocess/multiturn.py @@ -1,74 +1,119 @@ """ -Example script for preprocessing multi-turn conversation data into parquet format +Preprocess OpenHands SFT Trajectories dataset into parquet format for multi-turn training """ import os import argparse import datasets from verl.utils.hdfs_io import copy, makedirs +from transformers import AutoTokenizer -def process_conversation(example, idx, split): +def count_tokens(text, tokenizer): + """Count the number of tokens in a text""" + return len(tokenizer(text).input_ids) + + +def process_conversation(example, idx, split, tokenizer, max_tokens=32000): """Convert a conversation into the expected format""" messages = [] + total_tokens = 0 - # Add system message if present - if example.get('system', ''): - messages.append({ - "role": "system", - "content": example['system'] - }) + # Add system message + system_msg = { + "role": "system", + "content": "You are a helpful assistant that can understand and generate code." + } + total_tokens += count_tokens(system_msg["content"], tokenizer) + messages.append(system_msg) - # Add conversation turns - for turn in example['conversation']: - messages.append({ - "role": turn['role'], - "content": turn['content'] - }) + # Process each turn + for i in range(len(example['human'])): + # Add human message + human_msg = { + "role": "user", + "content": example['human'][i] + } + human_tokens = count_tokens(human_msg["content"], tokenizer) + + # Add assistant message + assistant_msg = { + "role": "assistant", + "content": example['assistant'][i] + } + assistant_tokens = count_tokens(assistant_msg["content"], tokenizer) + + # Check if adding these messages would exceed token limit + if total_tokens + human_tokens + assistant_tokens > max_tokens: + break + + total_tokens += human_tokens + assistant_tokens + messages.append(human_msg) + messages.append(assistant_msg) - # Return the processed data - return { - "data_source": "multiturn_example", - "messages": messages, - "extra_info": { - 'split': split, - 'index': idx, - 'original': example + # Only return if we have at least one complete turn + if len(messages) >= 3: # system + at least one human-assistant pair + return { + "data_source": "openhands_sft_trajectories", + "messages": messages, + "extra_info": { + 'split': split, + 'index': idx, + 'total_tokens': total_tokens, + 'original_id': example.get('id', None) + } } - } + return None if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--local_dir', default='~/data/multiturn') parser.add_argument('--hdfs_dir', default=None) + parser.add_argument('--max_tokens', type=int, default=32000) args = parser.parse_args() - # Load your dataset here - # This is just an example - replace with your actual data loading - dataset = datasets.load_dataset('your_dataset_name') + # Load tokenizer for token counting + tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-0.5B-Instruct') + # Load OpenHands dataset + dataset = 
datasets.load_dataset('SWE-Gym/OpenHands-SFT-Trajectories')
+
+    # Split into train/test (90/10 split)
+    dataset = dataset['train'].train_test_split(test_size=0.1, seed=42)
     train_dataset = dataset['train']
     test_dataset = dataset['test']
 
     # Process the datasets
     train_dataset = train_dataset.map(
-        function=lambda x, i: process_conversation(x, i, 'train'),
-        with_indices=True
+        function=lambda x, i: process_conversation(x, i, 'train', tokenizer, args.max_tokens),
+        with_indices=True,
+        remove_columns=train_dataset.column_names
     )
     test_dataset = test_dataset.map(
-        function=lambda x, i: process_conversation(x, i, 'test'),
-        with_indices=True
+        function=lambda x, i: process_conversation(x, i, 'test', tokenizer, args.max_tokens),
+        with_indices=True,
+        remove_columns=test_dataset.column_names
     )
 
-    # Save to parquet files
-    local_dir = args.local_dir
-    hdfs_dir = args.hdfs_dir
+    # Filter out None values (conversations that were too long)
+    train_dataset = train_dataset.filter(lambda x: x is not None)
+    test_dataset = test_dataset.filter(lambda x: x is not None)
+
+    # Create output directory
+    local_dir = os.path.expanduser(args.local_dir)
+    os.makedirs(local_dir, exist_ok=True)
 
+    # Save to parquet files
     train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet'))
     test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet'))
 
-    if hdfs_dir is not None:
-        makedirs(hdfs_dir)
-        copy(src=local_dir, dst=hdfs_dir)
\ No newline at end of file
+    if args.hdfs_dir is not None:
+        makedirs(args.hdfs_dir)
+        copy(src=local_dir, dst=args.hdfs_dir)
+
+    # Print statistics
+    print(f"Train dataset size: {len(train_dataset)}")
+    print(f"Test dataset size: {len(test_dataset)}")
+    print(f"Data saved to {local_dir}")
\ No newline at end of file
diff --git a/tests/sft/run_sft_multiturn.sh b/tests/sft/run_sft_multiturn.sh
new file mode 100755
index 00000000..48d2a2fb
--- /dev/null
+++ b/tests/sft/run_sft_multiturn.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+set -x
+
+if [ "$#" -lt 2 ]; then
+    echo "Usage: run_qwen_05_sp2.sh <nproc_per_node> <save_path> [other_configs...]"
+    exit 1
+fi
+
+nproc_per_node=$1
+save_path=$2
+
+# Shift the arguments so $@ refers to the rest
+shift 2
+
+torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \
+    -m verl.trainer.fsdp_sft_trainer \
+    data.train_files=$HOME/data/multiturn/train.parquet \
+    data.val_files=$HOME/data/multiturn/test.parquet \
+    data.use_multiturn=true \
+    data.messages_key=messages \
+    data.micro_batch_size=4 \
+    model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \
+    trainer.default_local_dir=$save_path \
+    trainer.project_name=multiturn-sft \
+    trainer.experiment_name=multiturn-sft-qwen-2.5-0.5b-instruct-sp2 \
+    trainer.logger=['console'] \
+    trainer.total_training_steps=1 \
+    trainer.default_hdfs_dir=null $@ \
+    ulysses_sequence_parallel_size=2 \
+    use_remove_padding=true
\ No newline at end of file

From d4685ba174f6b64176e104547d46a7e24b29c633 Mon Sep 17 00:00:00 2001
From: Xingyao Wang
Date: Mon, 3 Feb 2025 21:35:31 -0500
Subject: [PATCH 10/13] move file

---
 .../unit => verl/utils/dataset}/test_multiturn_sft_dataset.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename tests/{sft/unit => verl/utils/dataset}/test_multiturn_sft_dataset.py (100%)

diff --git a/tests/sft/unit/test_multiturn_sft_dataset.py b/tests/verl/utils/dataset/test_multiturn_sft_dataset.py
similarity index 100%
rename from tests/sft/unit/test_multiturn_sft_dataset.py
rename to tests/verl/utils/dataset/test_multiturn_sft_dataset.py

From 8f3e5c62aee86b7cd02db3fc684a49d271e5e47d Mon Sep 17 00:00:00 2001
From:
openhands Date: Thu, 13 Feb 2025 15:03:12 +0000 Subject: [PATCH 11/13] Apply code formatting --- examples/data_preprocess/multiturn.py | 55 ++++----- .../dataset/test_multiturn_sft_dataset.py | 108 ++++++++++-------- verl/trainer/fsdp_sft_trainer.py | 62 +++++----- verl/utils/dataset/multiturn_sft_dataset.py | 51 +++++---- 4 files changed, 138 insertions(+), 138 deletions(-) diff --git a/examples/data_preprocess/multiturn.py b/examples/data_preprocess/multiturn.py index 9899bcf9..4e513a73 100644 --- a/examples/data_preprocess/multiturn.py +++ b/examples/data_preprocess/multiturn.py @@ -18,39 +18,30 @@ def process_conversation(example, idx, split, tokenizer, max_tokens=32000): """Convert a conversation into the expected format""" messages = [] total_tokens = 0 - + # Add system message - system_msg = { - "role": "system", - "content": "You are a helpful assistant that can understand and generate code." - } + system_msg = {"role": "system", "content": "You are a helpful assistant that can understand and generate code."} total_tokens += count_tokens(system_msg["content"], tokenizer) messages.append(system_msg) - + # Process each turn for i in range(len(example['human'])): # Add human message - human_msg = { - "role": "user", - "content": example['human'][i] - } + human_msg = {"role": "user", "content": example['human'][i]} human_tokens = count_tokens(human_msg["content"], tokenizer) - + # Add assistant message - assistant_msg = { - "role": "assistant", - "content": example['assistant'][i] - } + assistant_msg = {"role": "assistant", "content": example['assistant'][i]} assistant_tokens = count_tokens(assistant_msg["content"], tokenizer) - + # Check if adding these messages would exceed token limit if total_tokens + human_tokens + assistant_tokens > max_tokens: break - + total_tokens += human_tokens + assistant_tokens messages.append(human_msg) messages.append(assistant_msg) - + # Only return if we have at least one complete turn if len(messages) >= 3: # system + at least one human-assistant pair return { @@ -71,49 +62,47 @@ def process_conversation(example, idx, split, tokenizer, max_tokens=32000): parser.add_argument('--local_dir', default='~/data/multiturn') parser.add_argument('--hdfs_dir', default=None) parser.add_argument('--max_tokens', type=int, default=32000) - + args = parser.parse_args() - + # Load tokenizer for token counting tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-0.5B-Instruct') - + # Load OpenHands dataset dataset = datasets.load_dataset('SWE-Gym/OpenHands-SFT-Trajectories') - + # Split into train/test (90/10 split) dataset = dataset['train'].train_test_split(test_size=0.1, seed=42) train_dataset = dataset['train'] test_dataset = dataset['test'] - + # Process the datasets train_dataset = train_dataset.map( function=lambda x, i: process_conversation(x, i, 'train', tokenizer, args.max_tokens), with_indices=True, - remove_columns=train_dataset.column_names - ) + remove_columns=train_dataset.column_names) test_dataset = test_dataset.map( function=lambda x, i: process_conversation(x, i, 'test', tokenizer, args.max_tokens), with_indices=True, - remove_columns=test_dataset.column_names - ) - + remove_columns=test_dataset.column_names) + # Filter out None values (conversations that were too long) train_dataset = train_dataset.filter(lambda x: x is not None) test_dataset = test_dataset.filter(lambda x: x is not None) - + # Create output directory local_dir = os.path.expanduser(args.local_dir) os.makedirs(local_dir, exist_ok=True) - + # Save to parquet files 
train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet')) test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet')) - + if args.hdfs_dir is not None: makedirs(args.hdfs_dir) copy(src=local_dir, dst=args.hdfs_dir) - + # Print statistics print(f"Train dataset size: {len(train_dataset)}") print(f"Test dataset size: {len(test_dataset)}") - print(f"Data saved to {local_dir}") \ No newline at end of file + print(f"Data saved to {local_dir}") diff --git a/tests/verl/utils/dataset/test_multiturn_sft_dataset.py b/tests/verl/utils/dataset/test_multiturn_sft_dataset.py index 47602313..fbd71e8d 100644 --- a/tests/verl/utils/dataset/test_multiturn_sft_dataset.py +++ b/tests/verl/utils/dataset/test_multiturn_sft_dataset.py @@ -12,54 +12,66 @@ def test_multiturn_sft_dataset(): print("Starting test...") # Create a temporary parquet file with test data test_data = { - 'messages': [ - [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is 2+2?"}, - {"role": "assistant", "content": "2+2 equals 4."}, - {"role": "user", "content": "And what is 4+4?"}, - {"role": "assistant", "content": "4+4 equals 8."} - ], - [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Tell me a joke."}, - {"role": "assistant", "content": "Why did the chicken cross the road?"}, - {"role": "user", "content": "Why?"}, - {"role": "assistant", "content": "To get to the other side!"} - ] - ] + 'messages': [[{ + "role": "system", + "content": "You are a helpful assistant." + }, { + "role": "user", + "content": "What is 2+2?" + }, { + "role": "assistant", + "content": "2+2 equals 4." + }, { + "role": "user", + "content": "And what is 4+4?" + }, { + "role": "assistant", + "content": "4+4 equals 8." + }], + [{ + "role": "system", + "content": "You are a helpful assistant." + }, { + "role": "user", + "content": "Tell me a joke." + }, { + "role": "assistant", + "content": "Why did the chicken cross the road?" + }, { + "role": "user", + "content": "Why?" + }, { + "role": "assistant", + "content": "To get to the other side!" 
+ }]] } - + # Create test directory if it doesn't exist os.makedirs('test_data', exist_ok=True) test_file = 'test_data/test.parquet' - + # Save test data to parquet df = pd.DataFrame(test_data) df.to_parquet(test_file) - + # Initialize tokenizer and dataset tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-Coder-7B-Instruct') - dataset = MultiTurnSFTDataset( - parquet_files=test_file, - tokenizer=tokenizer, - max_length=512 - ) - + dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, max_length=512) + # Test 1: Dataset Length assert len(dataset) == 2, f"Expected dataset length 2, got {len(dataset)}" - + # Get items for testing item0 = dataset[0] # Math conversation item1 = dataset[1] # Joke conversation - + # Test 2: Required Keys and Types required_keys = ['input_ids', 'attention_mask', 'position_ids', 'loss_mask'] for key in required_keys: assert key in item0, f"Missing key {key} in dataset item" assert isinstance(item0[key], torch.Tensor), f"Expected torch.Tensor for {key}" assert item0[key].dtype == torch.long, f"Expected torch.long for {key}, got {item0[key].dtype}" - + # Test 3: Shape Consistency assert item0['loss_mask'].shape == item0['input_ids'].shape, \ "Loss mask shape doesn't match input_ids shape" @@ -67,35 +79,35 @@ def test_multiturn_sft_dataset(): "Attention mask shape doesn't match input_ids shape" assert item0['position_ids'].shape == item0['input_ids'].shape, \ "Position IDs shape doesn't match input_ids shape" - + # Test 4: Loss Mask Pattern - Math Conversation loss_mask0 = item0['loss_mask'] input_ids0 = item0['input_ids'] - + # Find assistant response positions assistant_positions0 = torch.where(loss_mask0 == 1)[0] assert len(assistant_positions0) > 0, "No assistant positions found in loss mask" - + # Decode and verify assistant responses assistant_text0 = tokenizer.decode(input_ids0[loss_mask0 == 1]) print(f"Math conversation assistant text: {assistant_text0}") assert "2+2 equals 4" in assistant_text0, "First assistant response not found" assert "4+4 equals 8" in assistant_text0, "Second assistant response not found" - + # Test 5: Loss Mask Pattern - Joke Conversation loss_mask1 = item1['loss_mask'] input_ids1 = item1['input_ids'] - + # Find assistant response positions assistant_positions1 = torch.where(loss_mask1 == 1)[0] assert len(assistant_positions1) > 0, "No assistant positions found in loss mask" - + # Decode and verify assistant responses assistant_text1 = tokenizer.decode(input_ids1[loss_mask1 == 1]) print(f"Joke conversation assistant text: {assistant_text1}") assert "chicken cross the road" in assistant_text1, "First assistant response not found" assert "other side" in assistant_text1, "Second assistant response not found" - + # Test 6: Attention Mask Pattern attention_mask0 = item0['attention_mask'] sequence_length = torch.sum(attention_mask0) @@ -103,50 +115,50 @@ def test_multiturn_sft_dataset(): assert torch.all(attention_mask0[:sequence_length] == 1), "Incorrect attention mask pattern" if sequence_length < len(attention_mask0): assert torch.all(attention_mask0[sequence_length:] == 0), "Padding not properly masked" - + # Test 7: Position IDs Pattern position_ids0 = item0['position_ids'] assert torch.equal(position_ids0[:sequence_length], torch.arange(sequence_length)), \ "Position IDs not sequential for non-padded tokens" if sequence_length < len(position_ids0): assert torch.all(position_ids0[sequence_length:] == 0), "Padding position IDs not zero" - + # Test 8: Verify loss mask for assistant responses # Get the full 
conversation text full_text = tokenizer.decode(input_ids0) print(f"\nFull conversation text:\n{full_text}") - + # Get the assistant responses assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 1]) print(f"\nAssistant responses (from loss mask):\n{assistant_text}") - + # Verify that loss mask is set for all assistant responses for msg in test_data['messages'][0]: # First conversation if msg['role'] == 'assistant': # The content should appear in the masked text assert msg['content'] in assistant_text, \ f"Assistant message '{msg['content']}' not found in masked text" - + # The content should NOT appear in the non-masked text non_assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 0]) assert msg['content'] not in non_assistant_text, \ f"Assistant message '{msg['content']}' found in non-assistant text" - + # Test 9: Verify non-assistant parts have loss_mask=0 # Get non-assistant text non_assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 0]) print(f"\nNon-assistant text (from loss mask):\n{non_assistant_text}") - + # Verify that system and user messages are in the non-assistant text for msg in test_data['messages'][0]: # First conversation if msg['role'] in ['system', 'user']: assert msg['content'] in non_assistant_text, \ f"{msg['role'].title()} message '{msg['content']}' not found in non-assistant text" - + # And verify they're NOT in the assistant text assert msg['content'] not in assistant_text, \ f"{msg['role'].title()} message '{msg['content']}' found in assistant text" - + # Test 10: Verify padding behavior small_dataset = MultiTurnSFTDataset( parquet_files=test_file, @@ -154,10 +166,10 @@ def test_multiturn_sft_dataset(): max_length=1024 # Larger than needed to test padding ) padded_item = small_dataset[0] - + # Get actual sequence length (before padding) actual_length = torch.sum(padded_item['attention_mask']) - + # Verify padding tokens assert torch.all(padded_item['input_ids'][actual_length:] == tokenizer.pad_token_id), \ "Padding tokens not set correctly" @@ -165,5 +177,5 @@ def test_multiturn_sft_dataset(): "Attention mask not set correctly for padding" assert torch.all(padded_item['loss_mask'][actual_length:] == 0), \ "Loss mask not set correctly for padding" - - print("All tests passed!") \ No newline at end of file + + print("All tests passed!") diff --git a/verl/trainer/fsdp_sft_trainer.py b/verl/trainer/fsdp_sft_trainer.py index 2f27b668..07c69036 100644 --- a/verl/trainer/fsdp_sft_trainer.py +++ b/verl/trainer/fsdp_sft_trainer.py @@ -123,45 +123,37 @@ def _build_dataloader(self): config = self.config # build dataset dataset_class = MultiTurnSFTDataset if config.data.get('use_multiturn', False) else SFTDataset - + if dataset_class == MultiTurnSFTDataset: # Multi-turn dataset uses messages_key instead of prompt/response keys - self.train_dataset = dataset_class( - parquet_files=config.data.train_files, - tokenizer=self.tokenizer, - messages_key=config.data.messages_key, - max_length=config.data.max_length, - truncation=config.data.truncation - ) - self.val_dataset = dataset_class( - parquet_files=config.data.val_files, - tokenizer=self.tokenizer, - messages_key=config.data.messages_key, - max_length=config.data.max_length, - truncation=config.data.truncation - ) + self.train_dataset = dataset_class(parquet_files=config.data.train_files, + tokenizer=self.tokenizer, + messages_key=config.data.messages_key, + max_length=config.data.max_length, + truncation=config.data.truncation) + self.val_dataset = dataset_class(parquet_files=config.data.val_files, + 
tokenizer=self.tokenizer, + messages_key=config.data.messages_key, + max_length=config.data.max_length, + truncation=config.data.truncation) else: # Single-turn dataset uses prompt/response keys - self.train_dataset = dataset_class( - parquet_files=config.data.train_files, - tokenizer=self.tokenizer, - prompt_key=config.data.prompt_key, - prompt_dict_keys=config.data.get('prompt_dict_keys', None), - response_key=config.data.response_key, - response_dict_keys=config.data.get('response_dict_keys', None), - max_length=config.data.max_length, - truncation=config.data.truncation - ) - self.val_dataset = dataset_class( - parquet_files=config.data.val_files, - tokenizer=self.tokenizer, - prompt_key=config.data.prompt_key, - prompt_dict_keys=config.data.get('prompt_dict_keys', None), - response_key=config.data.response_key, - response_dict_keys=config.data.get('response_dict_keys', None), - max_length=config.data.max_length, - truncation=config.data.truncation - ) + self.train_dataset = dataset_class(parquet_files=config.data.train_files, + tokenizer=self.tokenizer, + prompt_key=config.data.prompt_key, + prompt_dict_keys=config.data.get('prompt_dict_keys', None), + response_key=config.data.response_key, + response_dict_keys=config.data.get('response_dict_keys', None), + max_length=config.data.max_length, + truncation=config.data.truncation) + self.val_dataset = dataset_class(parquet_files=config.data.val_files, + tokenizer=self.tokenizer, + prompt_key=config.data.prompt_key, + prompt_dict_keys=config.data.get('prompt_dict_keys', None), + response_key=config.data.response_key, + response_dict_keys=config.data.get('response_dict_keys', None), + max_length=config.data.max_length, + truncation=config.data.truncation) # build dataloader # Use data parallel rank and size instead of global rank and world size diff --git a/verl/utils/dataset/multiturn_sft_dataset.py b/verl/utils/dataset/multiturn_sft_dataset.py index 3a7a5b34..901d1715 100644 --- a/verl/utils/dataset/multiturn_sft_dataset.py +++ b/verl/utils/dataset/multiturn_sft_dataset.py @@ -19,12 +19,13 @@ class MultiTurnSFTDataset(Dataset): Dataset for multi-turn conversations where each assistant response should be trained """ - def __init__(self, - parquet_files: Union[str, List[str]], - tokenizer, - messages_key='messages', # Key for the messages list in the parquet file - max_length=1024, - truncation='error'): + def __init__( + self, + parquet_files: Union[str, List[str]], + tokenizer, + messages_key='messages', # Key for the messages list in the parquet file + max_length=1024, + truncation='error'): assert truncation in ['error', 'left', 'right'] self.truncation = truncation @@ -46,6 +47,7 @@ def _download(self): self.parquet_files[i] = copy_local_path_from_hdfs(parquet_file, verbose=True) def _read_files_and_process(self): + def series_to_item(ls): import pandas, numpy while isinstance(ls, (pandas.core.series.Series, numpy.ndarray)) and len(ls) == 1: @@ -57,7 +59,7 @@ def series_to_item(ls): dataframe = pd.read_parquet(parquet_file) dataframes.append(dataframe) self.dataframe = pd.concat(dataframes) - + # Extract messages list from dataframe self.messages = self.dataframe[self.messages_key].apply(series_to_item).tolist() @@ -69,27 +71,34 @@ def __getitem__(self, item): messages = self.messages[item] # First, get the full conversation tokens - full_tokens = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors='pt', add_generation_prompt=False) + full_tokens = tokenizer.apply_chat_template(messages, + tokenize=True, + 
return_tensors='pt', + add_generation_prompt=False) input_ids = full_tokens[0] # The output is already a tensor attention_mask = torch.ones_like(input_ids) - + # Create loss mask by identifying assistant responses loss_mask = torch.zeros_like(input_ids, dtype=torch.long) - + # Process each message to find assistant responses current_length = 0 for i, msg in enumerate(messages): # Get tokens for messages up to this point to find the start position - prefix_messages = messages[:i+1] - prefix_tokens = tokenizer.apply_chat_template(prefix_messages, tokenize=True, return_tensors='pt', add_generation_prompt=False) - + prefix_messages = messages[:i + 1] + prefix_tokens = tokenizer.apply_chat_template(prefix_messages, + tokenize=True, + return_tensors='pt', + add_generation_prompt=False) + # Get tokens for messages up to previous point - prev_tokens = tokenizer.apply_chat_template(messages[:i], tokenize=True, return_tensors='pt', add_generation_prompt=False) if i > 0 else None - + prev_tokens = tokenizer.apply_chat_template( + messages[:i], tokenize=True, return_tensors='pt', add_generation_prompt=False) if i > 0 else None + # Calculate start and end positions start_pos = prev_tokens[0].shape[0] if prev_tokens is not None else 0 end_pos = prefix_tokens[0].shape[0] - + # If this is an assistant message, set loss mask if msg['role'] == 'assistant': loss_mask[start_pos:end_pos] = 1 @@ -100,11 +109,9 @@ def __getitem__(self, item): # Pad sequences pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0 padded_input_ids = torch.ones(size=(self.max_length - sequence_length,), - dtype=input_ids.dtype) * pad_token_id - padded_attention_mask = torch.zeros(size=(self.max_length - sequence_length,), - dtype=attention_mask.dtype) - padded_loss_mask = torch.zeros(size=(self.max_length - sequence_length,), - dtype=loss_mask.dtype) + dtype=input_ids.dtype) * pad_token_id + padded_attention_mask = torch.zeros(size=(self.max_length - sequence_length,), dtype=attention_mask.dtype) + padded_loss_mask = torch.zeros(size=(self.max_length - sequence_length,), dtype=loss_mask.dtype) input_ids = torch.cat((input_ids, padded_input_ids)) attention_mask = torch.cat((attention_mask, padded_attention_mask)) @@ -133,4 +140,4 @@ def __getitem__(self, item): 'attention_mask': attention_mask, 'position_ids': position_ids, 'loss_mask': loss_mask - } \ No newline at end of file + } From e34b93225dfa72a364aa62a712c98eb27468ffd3 Mon Sep 17 00:00:00 2001 From: openhands Date: Thu, 13 Feb 2025 15:08:41 +0000 Subject: [PATCH 12/13] Add license headers to multiturn-related files --- examples/data_preprocess/multiturn.py | 137 ++++++++++++ .../dataset/test_multiturn_sft_dataset.py | 210 ++++++++++++++++++ verl/utils/dataset/multiturn_sft_dataset.py | 172 ++++++++++++++ 3 files changed, 519 insertions(+) diff --git a/examples/data_preprocess/multiturn.py b/examples/data_preprocess/multiturn.py index 4e513a73..7ee81c1c 100644 --- a/examples/data_preprocess/multiturn.py +++ b/examples/data_preprocess/multiturn.py @@ -1,3 +1,140 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Preprocess OpenHands SFT Trajectories dataset into parquet format for multi-turn training +""" + +import os +import argparse +import datasets +from verl.utils.hdfs_io import copy, makedirs +from transformers import AutoTokenizer + + +def count_tokens(text, tokenizer): + """Count the number of tokens in a text""" + return len(tokenizer(text).input_ids) + + +def process_conversation(example, idx, split, tokenizer, max_tokens=32000): + """Convert a conversation into the expected format""" + messages = [] + total_tokens = 0 + + # Add system message + system_msg = {"role": "system", "content": "You are a helpful assistant that can understand and generate code."} + total_tokens += count_tokens(system_msg["content"], tokenizer) + messages.append(system_msg) + + # Process each turn + for i in range(len(example['human'])): + # Add human message + human_msg = {"role": "user", "content": example['human'][i]} + human_tokens = count_tokens(human_msg["content"], tokenizer) + + # Add assistant message + assistant_msg = {"role": "assistant", "content": example['assistant'][i]} + assistant_tokens = count_tokens(assistant_msg["content"], tokenizer) + + # Check if adding these messages would exceed token limit + if total_tokens + human_tokens + assistant_tokens > max_tokens: + break + + total_tokens += human_tokens + assistant_tokens + messages.append(human_msg) + messages.append(assistant_msg) + + # Only return if we have at least one complete turn + if len(messages) >= 3: # system + at least one human-assistant pair + return { + "data_source": "openhands_sft_trajectories", + "messages": messages, + "extra_info": { + 'split': split, + 'index': idx, + 'total_tokens': total_tokens, + 'original_id': example.get('id', None) + } + } + return None + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--local_dir', default='~/data/multiturn') + parser.add_argument('--hdfs_dir', default=None) + parser.add_argument('--max_tokens', type=int, default=32000) + + args = parser.parse_args() + + # Load tokenizer for token counting + tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-0.5B-Instruct') + + # Load OpenHands dataset + dataset = datasets.load_dataset('SWE-Gym/OpenHands-SFT-Trajectories') + + # Split into train/test (90/10 split) + dataset = dataset['train'].train_test_split(test_size=0.1, seed=42) + train_dataset = dataset['train'] + test_dataset = dataset['test'] + + # Process the datasets + train_dataset = train_dataset.map( + function=lambda x, i: process_conversation(x, i, 'train', tokenizer, args.max_tokens), + with_indices=True, + remove_columns=train_dataset.column_names) + test_dataset = test_dataset.map( + function=lambda x, i: process_conversation(x, i, 'test', tokenizer, args.max_tokens), + with_indices=True, + remove_columns=test_dataset.column_names) + + # Filter out None values (conversations that were too long) + train_dataset = train_dataset.filter(lambda x: x is not None) + test_dataset = test_dataset.filter(lambda x: x is not None) + + # Create output directory + local_dir = os.path.expanduser(args.local_dir) + os.makedirs(local_dir, 
exist_ok=True) + + # Save to parquet files + train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet')) + test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet')) + + if args.hdfs_dir is not None: + makedirs(args.hdfs_dir) + copy(src=local_dir, dst=args.hdfs_dir) + + # Print statistics + print(f"Train dataset size: {len(train_dataset)}") + print(f"Test dataset size: {len(test_dataset)}") + print(f"Data saved to {local_dir}") +EOF > examples/data_preprocess/multiturn.py +# Copyright 2024 Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """ Preprocess OpenHands SFT Trajectories dataset into parquet format for multi-turn training """ diff --git a/tests/verl/utils/dataset/test_multiturn_sft_dataset.py b/tests/verl/utils/dataset/test_multiturn_sft_dataset.py index fbd71e8d..3c30a91e 100644 --- a/tests/verl/utils/dataset/test_multiturn_sft_dataset.py +++ b/tests/verl/utils/dataset/test_multiturn_sft_dataset.py @@ -1,3 +1,213 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Test the MultiTurnSFTDataset implementation +""" +import os +import pandas as pd +import torch +from transformers import AutoTokenizer +from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset + + +def test_multiturn_sft_dataset(): + print("Starting test...") + # Create a temporary parquet file with test data + test_data = { + 'messages': [[{ + "role": "system", + "content": "You are a helpful assistant." + }, { + "role": "user", + "content": "What is 2+2?" + }, { + "role": "assistant", + "content": "2+2 equals 4." + }, { + "role": "user", + "content": "And what is 4+4?" + }, { + "role": "assistant", + "content": "4+4 equals 8." + }], + [{ + "role": "system", + "content": "You are a helpful assistant." + }, { + "role": "user", + "content": "Tell me a joke." + }, { + "role": "assistant", + "content": "Why did the chicken cross the road?" + }, { + "role": "user", + "content": "Why?" + }, { + "role": "assistant", + "content": "To get to the other side!" 
+ }]] + } + + # Create test directory if it doesn't exist + os.makedirs('test_data', exist_ok=True) + test_file = 'test_data/test.parquet' + + # Save test data to parquet + df = pd.DataFrame(test_data) + df.to_parquet(test_file) + + # Initialize tokenizer and dataset + tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-Coder-7B-Instruct') + dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, max_length=512) + + # Test 1: Dataset Length + assert len(dataset) == 2, f"Expected dataset length 2, got {len(dataset)}" + + # Get items for testing + item0 = dataset[0] # Math conversation + item1 = dataset[1] # Joke conversation + + # Test 2: Required Keys and Types + required_keys = ['input_ids', 'attention_mask', 'position_ids', 'loss_mask'] + for key in required_keys: + assert key in item0, f"Missing key {key} in dataset item" + assert isinstance(item0[key], torch.Tensor), f"Expected torch.Tensor for {key}" + assert item0[key].dtype == torch.long, f"Expected torch.long for {key}, got {item0[key].dtype}" + + # Test 3: Shape Consistency + assert item0['loss_mask'].shape == item0['input_ids'].shape, \ + "Loss mask shape doesn't match input_ids shape" + assert item0['attention_mask'].shape == item0['input_ids'].shape, \ + "Attention mask shape doesn't match input_ids shape" + assert item0['position_ids'].shape == item0['input_ids'].shape, \ + "Position IDs shape doesn't match input_ids shape" + + # Test 4: Loss Mask Pattern - Math Conversation + loss_mask0 = item0['loss_mask'] + input_ids0 = item0['input_ids'] + + # Find assistant response positions + assistant_positions0 = torch.where(loss_mask0 == 1)[0] + assert len(assistant_positions0) > 0, "No assistant positions found in loss mask" + + # Decode and verify assistant responses + assistant_text0 = tokenizer.decode(input_ids0[loss_mask0 == 1]) + print(f"Math conversation assistant text: {assistant_text0}") + assert "2+2 equals 4" in assistant_text0, "First assistant response not found" + assert "4+4 equals 8" in assistant_text0, "Second assistant response not found" + + # Test 5: Loss Mask Pattern - Joke Conversation + loss_mask1 = item1['loss_mask'] + input_ids1 = item1['input_ids'] + + # Find assistant response positions + assistant_positions1 = torch.where(loss_mask1 == 1)[0] + assert len(assistant_positions1) > 0, "No assistant positions found in loss mask" + + # Decode and verify assistant responses + assistant_text1 = tokenizer.decode(input_ids1[loss_mask1 == 1]) + print(f"Joke conversation assistant text: {assistant_text1}") + assert "chicken cross the road" in assistant_text1, "First assistant response not found" + assert "other side" in assistant_text1, "Second assistant response not found" + + # Test 6: Attention Mask Pattern + attention_mask0 = item0['attention_mask'] + sequence_length = torch.sum(attention_mask0) + assert sequence_length > 0, "No tokens marked as attended in attention mask" + assert torch.all(attention_mask0[:sequence_length] == 1), "Incorrect attention mask pattern" + if sequence_length < len(attention_mask0): + assert torch.all(attention_mask0[sequence_length:] == 0), "Padding not properly masked" + + # Test 7: Position IDs Pattern + position_ids0 = item0['position_ids'] + assert torch.equal(position_ids0[:sequence_length], torch.arange(sequence_length)), \ + "Position IDs not sequential for non-padded tokens" + if sequence_length < len(position_ids0): + assert torch.all(position_ids0[sequence_length:] == 0), "Padding position IDs not zero" + + # Test 8: Verify loss mask for assistant 
responses + # Get the full conversation text + full_text = tokenizer.decode(input_ids0) + print(f"\nFull conversation text:\n{full_text}") + + # Get the assistant responses + assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 1]) + print(f"\nAssistant responses (from loss mask):\n{assistant_text}") + + # Verify that loss mask is set for all assistant responses + for msg in test_data['messages'][0]: # First conversation + if msg['role'] == 'assistant': + # The content should appear in the masked text + assert msg['content'] in assistant_text, \ + f"Assistant message '{msg['content']}' not found in masked text" + + # The content should NOT appear in the non-masked text + non_assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 0]) + assert msg['content'] not in non_assistant_text, \ + f"Assistant message '{msg['content']}' found in non-assistant text" + + # Test 9: Verify non-assistant parts have loss_mask=0 + # Get non-assistant text + non_assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 0]) + print(f"\nNon-assistant text (from loss mask):\n{non_assistant_text}") + + # Verify that system and user messages are in the non-assistant text + for msg in test_data['messages'][0]: # First conversation + if msg['role'] in ['system', 'user']: + assert msg['content'] in non_assistant_text, \ + f"{msg['role'].title()} message '{msg['content']}' not found in non-assistant text" + + # And verify they're NOT in the assistant text + assert msg['content'] not in assistant_text, \ + f"{msg['role'].title()} message '{msg['content']}' found in assistant text" + + # Test 10: Verify padding behavior + small_dataset = MultiTurnSFTDataset( + parquet_files=test_file, + tokenizer=tokenizer, + max_length=1024 # Larger than needed to test padding + ) + padded_item = small_dataset[0] + + # Get actual sequence length (before padding) + actual_length = torch.sum(padded_item['attention_mask']) + + # Verify padding tokens + assert torch.all(padded_item['input_ids'][actual_length:] == tokenizer.pad_token_id), \ + "Padding tokens not set correctly" + assert torch.all(padded_item['attention_mask'][actual_length:] == 0), \ + "Attention mask not set correctly for padding" + assert torch.all(padded_item['loss_mask'][actual_length:] == 0), \ + "Loss mask not set correctly for padding" + + print("All tests passed!") +EOF > tests/verl/utils/dataset/test_multiturn_sft_dataset.py +# Copyright 2024 Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """ Test the MultiTurnSFTDataset implementation """ diff --git a/verl/utils/dataset/multiturn_sft_dataset.py b/verl/utils/dataset/multiturn_sft_dataset.py index 901d1715..bcc0d6f0 100644 --- a/verl/utils/dataset/multiturn_sft_dataset.py +++ b/verl/utils/dataset/multiturn_sft_dataset.py @@ -1,3 +1,175 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Multi-turn SFT dataset that supports training on conversation data with multiple turns +""" + +from typing import List, Union + +import pandas as pd +import torch +from torch.utils.data import Dataset +from transformers import PreTrainedTokenizer + +from verl.utils.fs import copy_local_path_from_hdfs +from verl.utils.model import compute_position_id_with_mask +from verl.utils import hf_tokenizer + + +class MultiTurnSFTDataset(Dataset): + """ + Dataset for multi-turn conversations where each assistant response should be trained + """ + + def __init__( + self, + parquet_files: Union[str, List[str]], + tokenizer, + messages_key='messages', # Key for the messages list in the parquet file + max_length=1024, + truncation='error'): + assert truncation in ['error', 'left', 'right'] + self.truncation = truncation + + if not isinstance(parquet_files, List): + parquet_files = [parquet_files] + + self.parquet_files = parquet_files + if isinstance(tokenizer, str): + tokenizer = hf_tokenizer(tokenizer) + self.tokenizer: PreTrainedTokenizer = tokenizer + self.messages_key = messages_key + self.max_length = max_length + + self._download() + self._read_files_and_process() + + def _download(self): + for i, parquet_file in enumerate(self.parquet_files): + self.parquet_files[i] = copy_local_path_from_hdfs(parquet_file, verbose=True) + + def _read_files_and_process(self): + + def series_to_item(ls): + import pandas, numpy + while isinstance(ls, (pandas.core.series.Series, numpy.ndarray)) and len(ls) == 1: + ls = ls[0] + return ls + + dataframes = [] + for parquet_file in self.parquet_files: + dataframe = pd.read_parquet(parquet_file) + dataframes.append(dataframe) + self.dataframe = pd.concat(dataframes) + + # Extract messages list from dataframe + self.messages = self.dataframe[self.messages_key].apply(series_to_item).tolist() + + def __len__(self): + return len(self.messages) + + def __getitem__(self, item): + tokenizer = self.tokenizer + messages = self.messages[item] + + # First, get the full conversation tokens + full_tokens = tokenizer.apply_chat_template(messages, + tokenize=True, + return_tensors='pt', + add_generation_prompt=False) + input_ids = full_tokens[0] # The output is already a tensor + attention_mask = torch.ones_like(input_ids) + + # Create loss mask by identifying assistant responses + loss_mask = torch.zeros_like(input_ids, dtype=torch.long) + + # Process each message to find assistant responses + current_length = 0 + for i, msg in enumerate(messages): + # Get tokens for messages up to this point to find the start position + prefix_messages = messages[:i + 1] + prefix_tokens = tokenizer.apply_chat_template(prefix_messages, + tokenize=True, + return_tensors='pt', + add_generation_prompt=False) + + # Get tokens for messages up to previous point + prev_tokens = tokenizer.apply_chat_template( + messages[:i], tokenize=True, return_tensors='pt', add_generation_prompt=False) if i > 0 else None + + # Calculate start and end positions + start_pos = prev_tokens[0].shape[0] if prev_tokens is not None else 0 + end_pos = prefix_tokens[0].shape[0] + + # If this is an assistant message, set 
loss mask + if msg['role'] == 'assistant': + loss_mask[start_pos:end_pos] = 1 + + # Handle sequence length + sequence_length = input_ids.shape[0] + if sequence_length < self.max_length: + # Pad sequences + pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0 + padded_input_ids = torch.ones(size=(self.max_length - sequence_length,), + dtype=input_ids.dtype) * pad_token_id + padded_attention_mask = torch.zeros(size=(self.max_length - sequence_length,), dtype=attention_mask.dtype) + padded_loss_mask = torch.zeros(size=(self.max_length - sequence_length,), dtype=loss_mask.dtype) + + input_ids = torch.cat((input_ids, padded_input_ids)) + attention_mask = torch.cat((attention_mask, padded_attention_mask)) + loss_mask = torch.cat((loss_mask, padded_loss_mask)) + elif sequence_length > self.max_length: + if self.truncation == 'left': + input_ids = input_ids[-self.max_length:] + attention_mask = attention_mask[-self.max_length:] + loss_mask = loss_mask[-self.max_length:] + elif self.truncation == 'right': + input_ids = input_ids[:self.max_length] + attention_mask = attention_mask[:self.max_length] + loss_mask = loss_mask[:self.max_length] + elif self.truncation == 'error': + raise ValueError(f'{sequence_length=} is larger than {self.max_length=}') + else: + raise ValueError(f'Unknown truncation method {self.truncation}') + + # Create position IDs + position_ids = torch.arange(len(input_ids), dtype=torch.long) + # Zero out position IDs for padding + position_ids = position_ids * attention_mask + + return { + 'input_ids': input_ids, + 'attention_mask': attention_mask, + 'position_ids': position_ids, + 'loss_mask': loss_mask + } +EOF > verl/utils/dataset/multiturn_sft_dataset.py +# Copyright 2024 Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """ Multi-turn SFT dataset that supports training on conversation data with multiple turns """ From 8fb9c3bf18ea5b367dac37d9aec9bcbb94f475a0 Mon Sep 17 00:00:00 2001 From: openhands Date: Thu, 13 Feb 2025 16:40:49 +0000 Subject: [PATCH 13/13] Apply formatting changes to multi-turn related files --- examples/data_preprocess/multiturn.py | 4 +--- tests/verl/utils/dataset/test_multiturn_sft_dataset.py | 6 +++--- verl/utils/dataset/multiturn_sft_dataset.py | 6 +++--- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/examples/data_preprocess/multiturn.py b/examples/data_preprocess/multiturn.py index 7ee81c1c..b44d0690 100644 --- a/examples/data_preprocess/multiturn.py +++ b/examples/data_preprocess/multiturn.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- """ Preprocess OpenHands SFT Trajectories dataset into parquet format for multi-turn training """ @@ -120,7 +119,7 @@ def process_conversation(example, idx, split, tokenizer, max_tokens=32000): print(f"Train dataset size: {len(train_dataset)}") print(f"Test dataset size: {len(test_dataset)}") print(f"Data saved to {local_dir}") -EOF > examples/data_preprocess/multiturn.py +EOF > examples / data_preprocess / multiturn.py # Copyright 2024 Bytedance Ltd. and/or its affiliates # Licensed under the Apache License, Version 2.0 (the "License"); @@ -134,7 +133,6 @@ def process_conversation(example, idx, split, tokenizer, max_tokens=32000): # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """ Preprocess OpenHands SFT Trajectories dataset into parquet format for multi-turn training """ diff --git a/tests/verl/utils/dataset/test_multiturn_sft_dataset.py b/tests/verl/utils/dataset/test_multiturn_sft_dataset.py index 3c30a91e..9ea88736 100644 --- a/tests/verl/utils/dataset/test_multiturn_sft_dataset.py +++ b/tests/verl/utils/dataset/test_multiturn_sft_dataset.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """ Test the MultiTurnSFTDataset implementation """ @@ -193,7 +192,9 @@ def test_multiturn_sft_dataset(): "Loss mask not set correctly for padding" print("All tests passed!") -EOF > tests/verl/utils/dataset/test_multiturn_sft_dataset.py + + +EOF > tests / verl / utils / dataset / test_multiturn_sft_dataset.py # Copyright 2024 Bytedance Ltd. and/or its affiliates # Licensed under the Apache License, Version 2.0 (the "License"); @@ -207,7 +208,6 @@ def test_multiturn_sft_dataset(): # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """ Test the MultiTurnSFTDataset implementation """ diff --git a/verl/utils/dataset/multiturn_sft_dataset.py b/verl/utils/dataset/multiturn_sft_dataset.py index bcc0d6f0..4e5f2fa0 100644 --- a/verl/utils/dataset/multiturn_sft_dataset.py +++ b/verl/utils/dataset/multiturn_sft_dataset.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """ Multi-turn SFT dataset that supports training on conversation data with multiple turns """ @@ -155,7 +154,9 @@ def __getitem__(self, item): 'position_ids': position_ids, 'loss_mask': loss_mask } -EOF > verl/utils/dataset/multiturn_sft_dataset.py + + +EOF > verl / utils / dataset / multiturn_sft_dataset.py # Copyright 2024 Bytedance Ltd. and/or its affiliates # Licensed under the Apache License, Version 2.0 (the "License"); @@ -169,7 +170,6 @@ def __getitem__(self, item): # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """ Multi-turn SFT dataset that supports training on conversation data with multiple turns """