Skip to content

Commit

Permalink
Merge pull request #28 from AndrewNLauder/generate_local_models
Browse files Browse the repository at this point in the history
Generate CoreML models from local transformer models or HF repos
  • Loading branch information
atiorh authored Jan 20, 2025
2 parents 03898fd + 7138f59 commit 6dd1371
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 20 deletions.
21 changes: 12 additions & 9 deletions scripts/generate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ def cli():
parser.add_argument(
"--model-version",
required=True,
help="Whisper model version string that matches Hugging Face model hub name, "
"e.g. openai/whisper-tiny.en",
help="Whisper model version string that can be either:\n"
"1. A Hugging Face model hub name (e.g. openai/whisper-tiny.en)\n"
"2. A local directory containing the model files"
)
parser.add_argument(
"--generate-quantized-variants",
Expand Down Expand Up @@ -88,9 +89,7 @@ def cli():
args.test_model_version = args.model_version
args.palettizer_tests = args.generate_quantized_variants
args.context_prefill_tests = args.generate_decoder_context_prefill_data
args.persistent_cache_dir = os.path.join(
args.output_dir, args.model_version.replace("/", "_")
)
args.persistent_cache_dir = args.output_dir
if args.repo_path_suffix is not None:
args.persistent_cache_dir += f"_{args.repo_path_suffix}"

Expand Down Expand Up @@ -135,12 +134,16 @@ def upload_version(local_folder_path, model_version):

# Dump required metadata before upload
for filename in ["config.json", "generation_config.json"]:
with open(hf_hub_download(repo_id=model_version,
filename=filename), "r") as f:
if os.path.exists(model_version): # Local path
config_path = os.path.join(model_version, filename)
else: # HF hub path
config_path = hf_hub_download(repo_id=model_version, filename=filename)

with open(config_path, "r") as f:
model_file = json.load(f)
with open(os.path.join(local_folder_path, filename), "w") as f:
json.dump(model_file, f)
logger.info(f"Copied over {filename} from the original {model_version} repo")
logger.info(f"Copied over {filename} from the original model")

# Get whisperkittools commit hash
wkt_commit_hash = subprocess.run(
Expand Down Expand Up @@ -261,4 +264,4 @@ def get_dir_size(root_dir):
path = os.path.join(parent, f)
if not os.path.islink(path):
size_in_mb += os.path.getsize(path)
return size_in_mb / 1e6
return size_in_mb / 1e6
17 changes: 13 additions & 4 deletions tests/test_audio_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,19 @@ class TestWhisperMelSpectrogram(

@classmethod
def setUpClass(cls):
with open(
hf_hub_download(repo_id=TEST_WHISPER_VERSION, filename="config.json"), "r"
) as f:
n_mels = json.load(f)["num_mel_bins"]
# Try loading config from local path first
config_path = os.path.join(TEST_WHISPER_VERSION, "config.json")
if os.path.exists(config_path):
logger.info(f"Loading config from local path: {config_path}")
with open(config_path, "r") as f:
n_mels = json.load(f)["num_mel_bins"]
else:
# Fall back to downloading from HF hub
logger.info(f"Loading config from Hugging Face hub: {TEST_WHISPER_VERSION}")
with open(
hf_hub_download(repo_id=TEST_WHISPER_VERSION, filename="config.json"), "r"
) as f:
n_mels = json.load(f)["num_mel_bins"]

logger.info(
f"WhisperMelSpectrogram: n_mels={n_mels} for {TEST_WHISPER_VERSION}"
Expand Down
42 changes: 35 additions & 7 deletions tests/test_text_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from argmaxtools import test_utils as argmaxtools_test_utils
from argmaxtools.utils import get_fastest_device, get_logger
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers import AutoTokenizer, WhisperForConditionalGeneration
from transformers.models.whisper import modeling_whisper

from whisperkit import test_utils, text_decoder
Expand Down Expand Up @@ -44,6 +44,39 @@
TEST_TOKEN_TIMESTAMPS = True


def load_whisper_model(model_path: str, torch_dtype=None):
    """Return a ``WhisperForConditionalGeneration`` loaded from ``model_path``.

    ``model_path`` is first checked against the local filesystem. If it
    exists there, the model is loaded with ``local_files_only=True``;
    otherwise the same string is passed to ``from_pretrained`` as a
    Hugging Face model ID.

    Args:
        model_path: Local path or Hugging Face model ID.
        torch_dtype: Optional torch dtype forwarded to ``from_pretrained``.

    Returns:
        The loaded Whisper model.

    Raises:
        ValueError: If anything fails while resolving or loading the model;
            the underlying exception is attached as ``__cause__``.
    """
    logger.info(f"Attempting to load model from: {model_path}")
    try:
        # The existence check stays inside the try so that a failure in it
        # is reported through the same ValueError as a failed load.
        found_locally = os.path.exists(model_path)
        if found_locally:
            logger.info(f"Loading model from local path: {model_path}")
            source_kwargs = {"local_files_only": True}
        else:
            logger.info(f"Loading model from Hugging Face hub: {model_path}")
            source_kwargs = {}

        return WhisperForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
            **source_kwargs,
        )
    except Exception as load_error:
        message = (
            f"Could not load model from '{model_path}'. "
            "Make sure it is either a valid local path or Hugging Face model ID."
        )
        raise ValueError(message) from load_error


class TestWhisperTextDecoder(argmaxtools_test_utils.CoreMLTestsMixin, unittest.TestCase):
@classmethod
def setUpClass(cls):
Expand All @@ -55,12 +88,7 @@ def setUpClass(cls):
cls.test_output_names.pop(cls.test_output_names.index("alignment_heads_weights"))

# Original model
orig_torch_model = (
modeling_whisper.WhisperForConditionalGeneration.from_pretrained(
TEST_WHISPER_VERSION,
torch_dtype=TEST_TORCH_DTYPE,
)
)
orig_torch_model = load_whisper_model(TEST_WHISPER_VERSION, TEST_TORCH_DTYPE)
cls.orig_torch_model = (
orig_torch_model.model.decoder.to(TEST_DEV).to(TEST_TORCH_DTYPE).eval()
)
Expand Down

0 comments on commit 6dd1371

Please sign in to comment.