
Commit

fix an issue while quantizing llama models
JoseCarlosGarcia95 committed Dec 30, 2024
1 parent 677d03a commit 1ada7fa
Showing 4 changed files with 113 additions and 47 deletions.
4 changes: 4 additions & 0 deletions candle-core/src/quantized/gguf_file.rs
@@ -190,6 +190,10 @@ impl Value {
Self::F32(v)
}

pub fn from_string(v: String) -> Self {
Self::String(v)
}

pub fn to_u8(&self) -> Result<u8> {
match self {
Self::U8(v) => Ok(*v),
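For context, the new constructor mirrors the existing from_u32 / from_f32 helpers, so string-valued GGUF metadata can be built the same way as the numeric entries. A minimal sketch with hypothetical values, following the same pattern the tensor-tools changes below use for general.architecture:

let arch = gguf_file::Value::from_string("llama".to_string());
let heads = gguf_file::Value::from_u32(32);
let metadata = vec![
    ("general.architecture", arch),
    ("llama.attention.head_count", heads),
];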
38 changes: 22 additions & 16 deletions candle-examples/examples/quantized-bitnet/main.rs
@@ -291,15 +291,7 @@ fn main() -> anyhow::Result<()> {
for prompt_index in 0.. {
let prompt_str = match &prompt {
Prompt::One(prompt) => {
if args.which.is_falcon() {
format!("<|user|>\n{prompt}\n<|assistant|>")
} else if args.which.is_llama() {
format!(
"{prompt}"
)
} else {
prompt.clone()
}
prompt.clone()
}
Prompt::Interactive | Prompt::Chat => {
let is_interactive = matches!(prompt, Prompt::Interactive);
@@ -324,14 +316,15 @@ fn main() -> anyhow::Result<()> {
}
}
};

print!("{}", &prompt_str);
let tokens = tos
.tokenizer()
.encode(prompt_str, true)
.map_err(anyhow::Error::msg)?;
if args.verbose_prompt {
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
let token = token.to_string().replace('▁', " ").replace("<0x0A>", "\n");
println!("{id:7} -> '{token}'");
}
}
@@ -383,12 +376,24 @@ fn main() -> anyhow::Result<()> {
std::io::stdout().flush()?;
}

let eos_token = match args.which {
Which::Falcon3_10b1_58 | Which::Falcon3_7b1_58 | Which::Falcon3_3b1_58 | Which::Falcon3_1b1_58 => "<|endoftext|>",
Which::Llama3_8b1_58 => "<|eot_id|>",
let eos_tokens = match args.which {
Which::Falcon3_10b1_58 | Which::Falcon3_7b1_58 | Which::Falcon3_3b1_58 | Which::Falcon3_1b1_58 => {
vec!["<|endoftext|>"]
}
Which::Llama3_8b1_58 => {
vec!["<|eot_id|>"]
}
};

let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();

let eos_tokens: Vec<u32> = eos_tokens
.iter()
.map(|token| {
*tos.tokenizer()
.get_vocab(true)
.get(*token)
.unwrap_or_else(|| panic!("EoS token not found: {}", token))
})
.collect();

let start_post_prompt = std::time::Instant::now();
let mut sampled = 0;
@@ -413,7 +418,8 @@ fn main() -> anyhow::Result<()> {
std::io::stdout().flush()?;
}
sampled += 1;
if next_token == eos_token {

if eos_tokens.contains(&next_token) {
break;
};
}
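For reference, a minimal sketch of the multi-EOS stop logic above, with a plain HashMap standing in for the vocabulary returned by tos.tokenizer().get_vocab(true):

use std::collections::HashMap;

// Resolve each EOS string to its token id, panicking with a clear message if one is missing.
fn resolve_eos_ids(vocab: &HashMap<String, u32>, eos_strs: &[&str]) -> Vec<u32> {
    eos_strs
        .iter()
        .map(|s| *vocab.get(*s).unwrap_or_else(|| panic!("EoS token not found: {s}")))
        .collect()
}

// In the sampling loop, generation then stops on any of the resolved ids:
// if eos_ids.contains(&next_token) { break; }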
26 changes: 14 additions & 12 deletions candle-transformers/src/models/quantized_llama_bitnet.rs
@@ -54,8 +54,6 @@ struct BitQMatMul {
weight_scale: Tensor,
}



impl BitQMatMul {
fn from_qtensor(qtensor: QTensor, weight_scale: QTensor) -> Result<Self> {
let inner = candle::quantized::QMatMul::from_qtensor(qtensor)?;
@@ -65,23 +63,27 @@ impl BitQMatMul {
}

pub fn activation_quant(&self, x: &Tensor) -> Result<(Tensor, Tensor)> {
let last_dim = x.rank().saturating_sub(1);
let max_abs = x.abs()?.max_keepdim(last_dim)?;

let clamped = max_abs.clamp(1e-5, f32::INFINITY)?;
let scale = (127.0 / &clamped)?;

let scaled_rounded = x.broadcast_mul(&scale)?.round()?.clamp(-128f32, 127f32)?;

let target_dim = x.rank().saturating_sub(1);

let max_abs = x.abs()?.max_keepdim(target_dim)?;

let scale = (127.0/ &max_abs)?;

let scaled_rounded = x
.broadcast_mul(&scale)?
.round()?
.clamp(-128f32, 127f32)?;


Ok((scaled_rounded, scale))
}

fn forward(&self, x: &Tensor) -> Result<Tensor> {
let (x, xscale) = self.activation_quant(x)?;
let _enter = self.span.enter();
let scale = self.weight_scale.broadcast_mul(&xscale)?;
self.inner.forward(&x)?
.broadcast_div(&self.weight_scale)?
.broadcast_div(&xscale)
.broadcast_div(&scale)
}
}

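For context, activation_quant now performs per-row absmax quantization into the int8 range (the values stay in f32), clamping the row maximum at 1e-5 so the scale 127 / max|x| never divides by zero, and forward folds the weight and activation rescaling into a single broadcast_div by weight_scale * xscale. A minimal sketch of the same arithmetic on a plain f32 row, outside the Tensor API:

// scale = 127 / max(|x|, 1e-5); q = clamp(round(x * scale), -128, 127)
fn quant_row(x: &[f32]) -> (Vec<f32>, f32) {
    let max_abs = x.iter().fold(0f32, |m, v| m.max(v.abs())).max(1e-5);
    let scale = 127.0 / max_abs;
    let q = x
        .iter()
        .map(|v| (v * scale).round().clamp(-128.0, 127.0))
        .collect();
    (q, scale)
}
// The matmul output is later divided by (weight_scale * scale) to undo both scalings at once.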
92 changes: 73 additions & 19 deletions tensor-tools/src/main.rs
@@ -301,8 +301,10 @@ fn run_print(
println!("==== {name} ====");
match content.tensor(&mut file, name, device) {
Ok(tensor) => {
let dtype = tensor.dtype();

let tensor = tensor.dequantize(device)?;
println!("{tensor}")
println!("{tensor} {dtype:?}")
}
Err(e) => {
eprintln!("error: {e}");
@@ -438,12 +440,36 @@ fn unpack_bitnet_weights(tensor: &Tensor) -> Result<Tensor> {
Ok(unpacked_tensor)
}

use core::num;
use std::collections::HashMap;
use std::fs::File;
use std::path::PathBuf;
use rayon::prelude::*;
use serde_json::Value;

fn permute(weights: &Tensor, n_head: usize, n_head_kv: Option<usize>) -> Result<Tensor> {
let n_head = match n_head_kv {
Some(n_head_kv) if n_head != n_head_kv => n_head_kv,
_ => n_head,
};

let shape = weights.shape();
let shape0 = shape.dims()[0];
if shape0 % (n_head * 2) != 0 {
candle::bail!("weights.shape()[0] is not divisible by (n_head * 2)");
}

let mut new_shape = vec![n_head, 2, shape0 / (n_head * 2)];
new_shape.extend_from_slice(&shape.dims()[1..]);

let permuted = weights
.reshape(new_shape)?
.transpose(1, 2)?
.reshape(weights.shape())?;

Ok(permuted)
}
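A minimal usage sketch for the helper above, assuming the candle imports already used in this file and hypothetical tiny dimensions (a real Llama q_proj/k_proj would use the head counts read from the metadata JSON, as in the quantize path below):

use candle::{Device, Result, Tensor};

fn permute_demo() -> Result<()> {
    // hidden = 8 with 2 heads; the helper reorders the rows of a q/k projection for the GGUF llama layout.
    let w = Tensor::arange(0f32, 64f32, &Device::Cpu)?.reshape((8, 8))?;
    let q = permute(&w, 2, Some(2))?; // q_proj: kv head count equals head count
    let k = permute(&w, 2, Some(1))?; // k_proj: grouped-query case with a single kv head
    assert_eq!(q.dims(), w.dims());
    assert_eq!(k.dims(), w.dims());
    Ok(())
}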

fn run_quantize_safetensors(
in_files: &[PathBuf],
out_file: PathBuf,
@@ -459,6 +485,34 @@ fn run_quantize_safetensors(

let mut qtensors = Vec::new();

let mut num_attention_heads = 0;
let mut num_key_value_heads = 0;
let mut architecture = String::new();


let gguf_metadata = if let Some(metadata_file) = metadata_file {
let metadata_content = std::fs::read_to_string(metadata_file)?;
let metadata: serde_json::Value = serde_json::from_str(&metadata_content).unwrap();

num_attention_heads = metadata["num_attention_heads"].as_u64().unwrap();
num_key_value_heads = metadata["num_key_value_heads"].as_u64().unwrap();
architecture = metadata["model_type"].as_str().unwrap().to_string();

vec![
("llama.attention.head_count", gguf_file::Value::from_u32(num_attention_heads as u32)),
("llama.attention.head_count_kv", gguf_file::Value::from_u32(metadata["num_key_value_heads"].as_u64().unwrap() as u32)),
("llama.block_count", gguf_file::Value::from_u32(metadata["num_hidden_layers"].as_u64().unwrap() as u32)),
("llama.embedding_length", gguf_file::Value::from_u32(metadata["hidden_size"].as_u64().unwrap() as u32)),
("llama.attention.layer_norm_rms_epsilon", gguf_file::Value::from_f32(metadata["rms_norm_eps"].as_f64().unwrap() as f32)),
("llama.rope.dimension_count", gguf_file::Value::from_u32(
(metadata["hidden_size"].as_u64().unwrap() as u32) / (metadata["num_attention_heads"].as_u64().unwrap() as u32),
)),
("llama.rope.freq_base", gguf_file::Value::from_f32(metadata["rope_theta"].as_f64().unwrap() as f32)),
("general.architecture", gguf_file::Value::from_string(architecture.clone())),
]
} else {
vec![]
};
for in_file in in_files {
if let Some(metadata) = &metadata_file {
if Some(in_file) == Some(metadata) {
@@ -492,6 +546,24 @@ fn run_quantize_safetensors(
local_dtype = bq.clone().unwrap().dtype();
}
}

if name == "lm_head.weight" {
local_dtype = GgmlDType::Q6K;
}

// apply transformations to the tensors, based on the architecture
match architecture.as_str() {
"llama" => {
if name.ends_with("self_attn.q_proj.weight") {
tensor = permute(&tensor, num_attention_heads as usize, Some(num_attention_heads as usize))?;
}
if name.ends_with("self_attn.k_proj.weight") {
tensor = permute(&tensor, num_attention_heads as usize, Some(num_key_value_heads as usize))?;
}
}
_ => {}
}

println!(" quantizing {name} {tensor:?} {should_quantize}");
let tensor = if should_quantize {
QTensor::quantize(&tensor, local_dtype)?
@@ -530,24 +602,6 @@ fn run_quantize_safetensors(
.map(|(k, v)| (k.as_str(), v))
.collect::<Vec<_>>();

let gguf_metadata = if let Some(metadata_file) = metadata_file {
let metadata_content = std::fs::read_to_string(metadata_file)?;
let metadata: serde_json::Value = serde_json::from_str(&metadata_content).unwrap();

vec![
("llama.attention.head_count", gguf_file::Value::from_u32(metadata["num_attention_heads"].as_u64().unwrap() as u32)),
("llama.attention.head_count_kv", gguf_file::Value::from_u32(metadata["num_key_value_heads"].as_u64().unwrap() as u32)),
("llama.block_count", gguf_file::Value::from_u32(metadata["num_hidden_layers"].as_u64().unwrap() as u32)),
("llama.embedding_length", gguf_file::Value::from_u32(metadata["hidden_size"].as_u64().unwrap() as u32)),
("llama.attention.layer_norm_rms_epsilon", gguf_file::Value::from_f32(metadata["rms_norm_eps"].as_f64().unwrap() as f32)),
("llama.rope.dimension_count", gguf_file::Value::from_u32(
(metadata["hidden_size"].as_u64().unwrap() as u32) / (metadata["num_attention_heads"].as_u64().unwrap() as u32),
)),
]
} else {
vec![]
};

gguf_file::write(&mut out_file, &gguf_metadata, &qtensors)?;
Ok(())
}
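For reference, a minimal sketch of the metadata JSON that run_quantize_safetensors expects when a metadata file is passed; the field names are exactly the keys read above, and the values are hypothetical (roughly what a Llama-3-8B config.json contains):

let metadata_content = r#"{
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "num_hidden_layers": 32,
    "hidden_size": 4096,
    "rms_norm_eps": 1e-5,
    "rope_theta": 500000.0
}"#;
let metadata: serde_json::Value = serde_json::from_str(metadata_content).unwrap();
assert_eq!(metadata["num_attention_heads"].as_u64(), Some(32));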
