Skip to content

Commit

Permalink
support modelscope models and datasets (#1481)
Browse files Browse the repository at this point in the history
* support modelscope

* change modelscope args

* remove useless import

* remove useless import

* fix

* wip

* fix

* remove useless code

* add readme

* add some comments

* change print to raise error

* update comment

* Update loader.py

---------

Co-authored-by: Daniel Han <[email protected]>
  • Loading branch information
tastelikefeet and danielhanchen authored Jan 7, 2025
1 parent 83b48a8 commit e0ccfaf
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 2 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@ For **advanced installation instructions** or if you see weird errors during ins
- Go to our official [Documentation](https://docs.unsloth.ai) for saving to GGUF, checkpointing, evaluation and more!
- We support Huggingface's TRL, Trainer, Seq2SeqTrainer or even Pytorch code!
- We're in 🤗Hugging Face's official docs! Check out the [SFT docs](https://huggingface.co/docs/trl/main/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth) and [DPO docs](https://huggingface.co/docs/trl/main/en/dpo_trainer#accelerate-dpo-fine-tuning-using-unsloth)!
- If you want to download models from the ModelScope community, please use an environment variable: `UNSLOTH_USE_MODELSCOPE=1`, and install the modelscope library by: `pip install modelscope -U`.

> unsloth_cli.py also supports `UNSLOTH_USE_MODELSCOPE=1` to download models and datasets. please remember to use the model and dataset id in the ModelScope community.
```python
from unsloth import FastLanguageModel
Expand Down
12 changes: 10 additions & 2 deletions unsloth-cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,14 @@
"""

import argparse
import os


def run(args):
import torch
from unsloth import FastLanguageModel
from datasets import load_dataset
from transformers.utils import strtobool
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported
Expand Down Expand Up @@ -86,8 +89,13 @@ def formatting_prompts_func(examples):
texts.append(text)
return {"text": texts}

# Load and format dataset
dataset = load_dataset(args.dataset, split="train")
use_modelscope = strtobool(os.environ.get('UNSLOTH_USE_MODELSCOPE', 'False'))
if use_modelscope:
from modelscope import MsDataset
dataset = MsDataset.load(args.dataset, split="train")
else:
# Load and format dataset
dataset = load_dataset(args.dataset, split="train")
dataset = dataset.map(formatting_prompts_func, batched=True)
print("Data is formatted and ready!")

Expand Down
19 changes: 19 additions & 0 deletions unsloth/models/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@
pass
from huggingface_hub import HfFileSystem

# [TODO] Move USE_MODELSCOPE to utils
USE_MODELSCOPE = os.environ.get("UNSLOTH_USE_MODELSCOPE", "0") == "1"
if USE_MODELSCOPE:
import importlib
if importlib.util.find_spec("modelscope") is None:
raise ImportError(f'You are using the modelscope hub, please install modelscope by `pip install modelscope -U`')
pass
pass

# https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading!
from unsloth_zoo.utils import Version, _get_dtype
transformers_version = Version(transformers_version)
Expand Down Expand Up @@ -72,6 +81,11 @@ def from_pretrained(
if not use_exact_model_name:
model_name = get_model_name(model_name, load_in_4bit)

if USE_MODELSCOPE and not os.path.exists(model_name):
from modelscope import snapshot_download
model_name = snapshot_download(model_name)
pass

# First check if it's a normal model via AutoConfig
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled
was_disabled = are_progress_bars_disabled()
Expand Down Expand Up @@ -355,6 +369,11 @@ def from_pretrained(
if not use_exact_model_name:
model_name = get_model_name(model_name, load_in_4bit)

if USE_MODELSCOPE and not os.path.exists(model_name):
from modelscope import snapshot_download
model_name = snapshot_download(model_name)
pass

# First check if it's a normal model via AutoConfig
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled
was_disabled = are_progress_bars_disabled()
Expand Down

0 comments on commit e0ccfaf

Please sign in to comment.