diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 0d836c7fd4..2731456d43 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -544,7 +544,17 @@ impl<'a> VarBuilder<'a> { let pth = candle::pickle::PthTensors::new(p, None)?; Ok(Self::from_backend(Box::new(pth), dtype, dev.clone())) } - + /// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file. + /// similar to [`from_pth`] but requires a `state_key`. + pub fn from_pth_with_state>( + p: P, + dtype: DType, + state_key: &str, + dev: &Device, + ) -> Result { + let pth = candle::pickle::PthTensors::new(p, Some(state_key))?; + Ok(Self::from_backend(Box::new(pth), dtype, dev.clone())) + } /// Gets a VarBuilder that applies some renaming function on tensor it gets queried for before /// passing the new names to the inner VarBuilder. ///