Skip to content

Commit

Permalink
Provide a method to allow PTH files with state maps to be loaded. (hu…
Browse files Browse the repository at this point in the history
…ggingface#2639)

* Provide a method to allow PTH files iwth state maps to be loaded.

* add a line to the doc

* String-. &str
  • Loading branch information
zachcp authored and imihalcea committed Nov 26, 2024
1 parent 83723a4 commit 41f3113
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion candle-nn/src/var_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,17 @@ impl<'a> VarBuilder<'a> {
let pth = candle::pickle::PthTensors::new(p, None)?;
Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
}

/// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file.
/// similar to [`from_pth`] but requires a `state_key`.
pub fn from_pth_with_state<P: AsRef<std::path::Path>>(
p: P,
dtype: DType,
state_key: &str,
dev: &Device,
) -> Result<Self> {
let pth = candle::pickle::PthTensors::new(p, Some(state_key))?;
Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
}
/// Gets a VarBuilder that applies some renaming function on tensor it gets queried for before
/// passing the new names to the inner VarBuilder.
///
Expand Down

0 comments on commit 41f3113

Please sign in to comment.