diff --git a/jllm/cat2hf.py b/jllm/cat2hf.py index 6c79191..ae7f80c 100644 --- a/jllm/cat2hf.py +++ b/jllm/cat2hf.py @@ -36,10 +36,10 @@ def func(meta_data): model_file =f"model-{pipe_rank:05d}-of-"+f"{num_stages:05d}.safetensors" for k in tqdm.tqdm(keys): - if "o_proj" in k or "down_proj" in k in k: + if "o_proj" in k or "down_proj" in k or 'attn.proj.weight' in k or 'mlp.fc2.weight' in k or 'mlp.2.weight' in k : state_dict[k] = torch.cat([p[1].pop(k) for p in pts],1) - elif "lm_head" in k or "gate_proj" in k or "up_proj" in k or "embed_tokens" in k\ - or "q_proj" in k or "k_proj" in k or "v_proj" in k: + elif "lm_head" in k or "gate_proj" in k or "up_proj" in k or "embed_tokens" in k or "q_proj" in k \ + or "k_proj" in k or "v_proj" in k or 'attn.qkv' in k or 'mlp.fc1' in k or 'mlp.0' in k: state_dict[k] = torch.cat([p[1].pop(k) for p in pts]) else: state_dict[k] = pts[0][1].pop(k) diff --git a/jllm/data/utils.cpython-311-x86_64-linux-gnu.so b/jllm/data/utils.cpython-311-x86_64-linux-gnu.so index 2a6a65c..85652ac 100644 Binary files a/jllm/data/utils.cpython-311-x86_64-linux-gnu.so and b/jllm/data/utils.cpython-311-x86_64-linux-gnu.so differ diff --git a/jllm/data/utils.cpython-39-aarch64-linux-gnu.so b/jllm/data/utils.cpython-39-aarch64-linux-gnu.so index e19face..426f843 100644 Binary files a/jllm/data/utils.cpython-39-aarch64-linux-gnu.so and b/jllm/data/utils.cpython-39-aarch64-linux-gnu.so differ diff --git a/jllm/model/qwen2_vl/parallel_qwen2_vl.cpython-311-x86_64-linux-gnu.so b/jllm/model/qwen2_vl/parallel_qwen2_vl.cpython-311-x86_64-linux-gnu.so index 7e61032..f5c7013 100644 Binary files a/jllm/model/qwen2_vl/parallel_qwen2_vl.cpython-311-x86_64-linux-gnu.so and b/jllm/model/qwen2_vl/parallel_qwen2_vl.cpython-311-x86_64-linux-gnu.so differ diff --git a/jllm/model/qwen2_vl/parallel_qwen2_vl.cpython-39-aarch64-linux-gnu.so b/jllm/model/qwen2_vl/parallel_qwen2_vl.cpython-39-aarch64-linux-gnu.so index be733a0..1d3bb6a 100644 Binary files a/jllm/model/qwen2_vl/parallel_qwen2_vl.cpython-39-aarch64-linux-gnu.so and b/jllm/model/qwen2_vl/parallel_qwen2_vl.cpython-39-aarch64-linux-gnu.so differ diff --git a/setup.py b/setup.py index a2d8af3..5f4a4ae 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ import io project_name = "jllm" -version = "4.0.2" +version = "4.0.3" setuptools.setup( name=project_name,