diff --git a/gen/gen.ml b/gen/gen.ml
index 7bb445d5..4d7f8017 100644
--- a/gen/gen.ml
+++ b/gen/gen.ml
@@ -92,6 +92,7 @@ let excluded_prefixes =
   ; "_amp_foreach"
   ; "_nested_tensor"
   ; "_fused_adam"
+  ; "_fused_adagrad"
   ; "sym_"
   ; "_fused_sgd"
   ]
@@ -590,7 +591,7 @@ let write_cpp funcs filename =
   let pc s = p out_cpp s in
   let ph s = p out_h s in
   pc "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!";
-  pc "#include \"%s.h\"" (Caml.Filename.basename filename);
+  pc "#include \"%s.h\"" (Stdlib.Filename.basename filename);
   pc "";
   ph "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!";
   ph "#include \"torch_api.h\"";
diff --git a/src/wrappers/tensor_fallible_generated.rs b/src/wrappers/tensor_fallible_generated.rs
index c28a5f40..426d7cae 100644
--- a/src/wrappers/tensor_fallible_generated.rs
+++ b/src/wrappers/tensor_fallible_generated.rs
@@ -681,6 +681,164 @@ impl Tensor {
         Ok(Tensor { c_tensor: c_tensors[0] })
     }

+    pub fn f_internal_batch_norm_no_update<T: Borrow<Tensor>>(
+        &self,
+        weight: Option<T>,
+        bias: Option<T>,
+        running_mean: Option<T>,
+        running_var: Option<T>,
+        momentum: f64,
+        eps: f64,
+    ) -> Result<(Tensor, Tensor, Tensor, Tensor), TchError> {
+        let mut c_tensors = [std::ptr::null_mut(); 4];
+        unsafe_torch_err!(atg__batch_norm_no_update(
+            c_tensors.as_mut_ptr(),
+            self.c_tensor,
+            weight.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor),
+            bias.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor),
+            running_mean.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor),
+            running_var.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor),
+            momentum,
+            eps
+        ));
+        Ok((
+            Tensor { c_tensor: c_tensors[0] },
+            Tensor { c_tensor: c_tensors[1] },
+            Tensor { c_tensor: c_tensors[2] },
+            Tensor { c_tensor: c_tensors[3] },
+        ))
+    }
+
+    pub fn f_internal_batch_norm_no_update_out<T: Borrow<Tensor>>(
+        &self,
+        out0: &Tensor,
+        out1: &Tensor,
+        out2: &Tensor,
+        out3: &Tensor,
+        weight: Option<T>,
+        bias: Option<T>,
+        running_mean: Option<T>,
+        running_var: Option<T>,
+        momentum: f64,
+        eps: f64,
+    ) -> Result<(Tensor, Tensor, Tensor, Tensor), TchError> {
+        let mut c_tensors = [std::ptr::null_mut(); 4];
+        unsafe_torch_err!(atg__batch_norm_no_update_out(
+            c_tensors.as_mut_ptr(),
+            out0.c_tensor,
+            out1.c_tensor,
+            out2.c_tensor,
+            out3.c_tensor,
+            self.c_tensor,
+            weight.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor),
+            bias.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor),
+            running_mean.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor),
+            running_var.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor),
+            momentum,
+            eps
+        ));
+        Ok((
+            Tensor { c_tensor: c_tensors[0] },
+            Tensor { c_tensor: c_tensors[1] },
+            Tensor { c_tensor: c_tensors[2] },
+            Tensor { c_tensor: c_tensors[3] },
+        ))
+    }
+
+    pub fn f_internal_batch_norm_with_update<T: Borrow<Tensor>>(
+        &self,
+        weight: Option<T>,
+        bias: Option<T>,
+        running_mean: &Tensor,
+        running_var: &Tensor,
+        momentum: f64,
+        eps: f64,
+    ) -> Result<(Tensor, Tensor, Tensor, Tensor), TchError> {
+        let mut c_tensors = [std::ptr::null_mut(); 4];
+        unsafe_torch_err!(atg__batch_norm_with_update(
+            c_tensors.as_mut_ptr(),
+            self.c_tensor,
+            weight.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor),
+            bias.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor),
+            running_mean.c_tensor,
+            running_var.c_tensor,
+            momentum,
+            eps
+        ));
+        Ok((
+            Tensor { c_tensor: c_tensors[0] },
+            Tensor { c_tensor: c_tensors[1] },
+            Tensor { c_tensor: c_tensors[2] },
+            Tensor { c_tensor: c_tensors[3] },
+        ))
+    }
+
+    pub fn f_internal_batch_norm_with_update_functional<T: Borrow<Tensor>>(
+        &self,
+        weight: Option<T>,
+        bias: Option<T>,
+        running_mean: &Tensor,
+        running_var: &Tensor,
+        momentum: f64,
+        eps: f64,
+    ) -> Result<(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor), TchError> {
+        let mut c_tensors = [std::ptr::null_mut(); 6];
+        unsafe_torch_err!(atg__batch_norm_with_update_functional(
+            c_tensors.as_mut_ptr(),
+            self.c_tensor,
+            weight.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor),
+            bias.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor),
+            running_mean.c_tensor,
+            running_var.c_tensor,
+            momentum,
+            eps
+        ));
+        Ok((
+            Tensor { c_tensor: c_tensors[0] },
+            Tensor { c_tensor: c_tensors[1] },
+            Tensor { c_tensor: c_tensors[2] },
+            Tensor { c_tensor: c_tensors[3] },
+            Tensor { c_tensor: c_tensors[4] },
+            Tensor { c_tensor: c_tensors[5] },
+        ))
+    }
+
+    pub fn f_internal_batch_norm_with_update_out<T: Borrow<Tensor>>(
+        &self,
+        out: &Tensor,
+        save_mean: &Tensor,
+        save_invstd: &Tensor,
+        reserve: &Tensor,
+        weight: Option<T>,
+        bias: Option<T>,
+        running_mean: &Tensor,
+        running_var: &Tensor,
+        momentum: f64,
+        eps: f64,
+    ) -> Result<(Tensor, Tensor, Tensor, Tensor), TchError> {
+        let mut c_tensors = [std::ptr::null_mut(); 4];
+        unsafe_torch_err!(atg__batch_norm_with_update_out(
+            c_tensors.as_mut_ptr(),
+            out.c_tensor,
+            save_mean.c_tensor,
+            save_invstd.c_tensor,
+            reserve.c_tensor,
+            self.c_tensor,
+            weight.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor),
+            bias.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor),
+            running_mean.c_tensor,
+            running_var.c_tensor,
+            momentum,
+            eps
+        ));
+        Ok((
+            Tensor { c_tensor: c_tensors[0] },
+            Tensor { c_tensor: c_tensors[1] },
+            Tensor { c_tensor: c_tensors[2] },
+            Tensor { c_tensor: c_tensors[3] },
+        ))
+    }
+
     pub fn f_internal_cast_byte(&self, non_blocking: bool) -> Result<Tensor, TchError> {
         let mut c_tensors = [std::ptr::null_mut(); 1];
         unsafe_torch_err!(atg__cast_byte(
@@ -1900,9 +2058,12 @@ impl Tensor {
         bias_requires_grad: bool,
         scale: impl Into<Option<f64>>,
         num_splits_key: impl Into<Option<i64>>,
+        window_size: impl Into<Option<i64>>,
+        shared_storage_dqdkdv: bool,
     ) -> Result<(Tensor, Tensor, Tensor, Tensor), TchError> {
         let scale = scale.into();
         let num_splits_key = num_splits_key.into();
+        let window_size = window_size.into();
         let mut c_tensors = [std::ptr::null_mut(); 4];
         unsafe_torch_err!(atg__efficient_attention_backward(
             c_tensors.as_mut_ptr(),
@@ -1925,7 +2086,10 @@ impl Tensor {
             scale.unwrap_or(std::f64::NAN),
             scale.is_none() as i8,
             num_splits_key.unwrap_or(0i64),
-            num_splits_key.is_none() as i8
+            num_splits_key.is_none() as i8,
+            window_size.unwrap_or(0i64),
+            window_size.is_none() as i8,
+            if shared_storage_dqdkdv { 1 } else { 0 }
         ));
         Ok((
             Tensor { c_tensor: c_tensors[0] },
@@ -2718,8 +2882,12 @@ impl Tensor {
         philox_seed: &Tensor,
         philox_offset: &Tensor,
         scale: impl Into<Option<f64>>,
+        window_size_left: impl Into<Option<i64>>,
+        window_size_right: impl Into<Option<i64>>,
     ) -> Result<(Tensor, Tensor, Tensor), TchError> {
         let scale = scale.into();
+        let window_size_left = window_size_left.into();
+        let window_size_right = window_size_right.into();
         let mut c_tensors = [std::ptr::null_mut(); 3];
         unsafe_torch_err!(atg__flash_attention_backward(
             c_tensors.as_mut_ptr(),
@@ -2738,7 +2906,11 @@ impl Tensor {
             philox_seed.c_tensor,
             philox_offset.c_tensor,
             scale.unwrap_or(std::f64::NAN),
-            scale.is_none() as i8
+            scale.is_none() as i8,
+            window_size_left.unwrap_or(0i64),
+            window_size_left.is_none() as i8,
+            window_size_right.unwrap_or(0i64),
+            window_size_right.is_none() as i8
         ));
         Ok((
             Tensor { c_tensor: c_tensors[0] },
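[review note] A quick sketch of how the new fallible batch-norm entry points above could be driven from user code. The helper name, shapes, momentum and eps below are illustrative assumptions, not anything taken from this patch.

```rust
use tch::{Device, Kind, TchError, Tensor};

fn batch_norm_no_update_sketch() -> Result<(), TchError> {
    let opts = (Kind::Float, Device::Cpu);
    let x = Tensor::randn(&[8i64, 4, 16, 16][..], opts);
    let weight = Tensor::ones(&[4i64][..], opts);
    let bias = Tensor::zeros(&[4i64][..], opts);
    let running_mean = Tensor::zeros(&[4i64][..], opts);
    let running_var = Tensor::ones(&[4i64][..], opts);
    // Inference-style batch norm: the running statistics are read but not modified.
    let (out, _save_mean, _save_invstd, _reserve) = x.f_internal_batch_norm_no_update(
        Some(&weight),
        Some(&bias),
        Some(&running_mean),
        Some(&running_var),
        0.1,
        1e-5,
    )?;
    assert_eq!(out.size(), x.size());
    Ok(())
}
```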
@@ -3278,68 +3450,6 @@ impl Tensor {
         Ok(Tensor { c_tensor: c_tensors[0] })
     }

-    pub fn f_internal_index_put_impl<T: Borrow<Tensor>>(
-        &self,
-        indices: &[Option<T>],
-        values: &Tensor,
-        accumulate: bool,
-        unsafe_: bool,
-    ) -> Result<Tensor, TchError> {
-        let mut c_tensors = [std::ptr::null_mut(); 1];
-        unsafe_torch_err!(atg__index_put_impl(
-            c_tensors.as_mut_ptr(),
-            self.c_tensor,
-            ptr_list_opt(indices).as_ptr(),
-            indices.len() as i32,
-            values.c_tensor,
-            if accumulate { 1 } else { 0 },
-            if unsafe_ { 1 } else { 0 }
-        ));
-        Ok(Tensor { c_tensor: c_tensors[0] })
-    }
-
-    pub fn f_internal_index_put_impl_<T: Borrow<Tensor>>(
-        &mut self,
-        indices: &[Option<T>],
-        values: &Tensor,
-        accumulate: bool,
-        unsafe_: bool,
-    ) -> Result<Tensor, TchError> {
-        let mut c_tensors = [std::ptr::null_mut(); 1];
-        unsafe_torch_err!(atg__index_put_impl_(
-            c_tensors.as_mut_ptr(),
-            self.c_tensor,
-            ptr_list_opt(indices).as_ptr(),
-            indices.len() as i32,
-            values.c_tensor,
-            if accumulate { 1 } else { 0 },
-            if unsafe_ { 1 } else { 0 }
-        ));
-        Ok(Tensor { c_tensor: c_tensors[0] })
-    }
-
-    pub fn f_internal_index_put_impl_out<T: Borrow<Tensor>>(
-        &self,
-        out: &Tensor,
-        indices: &[Option<T>],
-        values: &Tensor,
-        accumulate: bool,
-        unsafe_: bool,
-    ) -> Result<Tensor, TchError> {
-        let mut c_tensors = [std::ptr::null_mut(); 1];
-        unsafe_torch_err!(atg__index_put_impl_out(
-            c_tensors.as_mut_ptr(),
-            out.c_tensor,
-            self.c_tensor,
-            ptr_list_opt(indices).as_ptr(),
-            indices.len() as i32,
-            values.c_tensor,
-            if accumulate { 1 } else { 0 },
-            if unsafe_ { 1 } else { 0 }
-        ));
-        Ok(Tensor { c_tensor: c_tensors[0] })
-    }
-
     pub fn f_internal_indices(&self) -> Result<Tensor, TchError> {
         let mut c_tensors = [std::ptr::null_mut(); 1];
         unsafe_torch_err!(atg__indices(c_tensors.as_mut_ptr(), self.c_tensor));
@@ -4560,6 +4670,17 @@ impl Tensor {
         Ok(Tensor { c_tensor: c_tensors[0] })
     }

+    pub fn f_internal_nested_compute_contiguous_strides_offsets(
+        nested_size: &Tensor,
+    ) -> Result<(Tensor, Tensor), TchError> {
+        let mut c_tensors = [std::ptr::null_mut(); 2];
+        unsafe_torch_err!(atg__nested_compute_contiguous_strides_offsets(
+            c_tensors.as_mut_ptr(),
+            nested_size.c_tensor
+        ));
+        Ok((Tensor { c_tensor: c_tensors[0] }, Tensor { c_tensor: c_tensors[1] }))
+    }
+
     pub fn f_internal_nested_from_padded(
         padded: &Tensor,
         cpu_nested_shape_example: &Tensor,
@@ -5300,25 +5421,41 @@ impl Tensor {
         Ok((Tensor { c_tensor: c_tensors[0] }, Tensor { c_tensor: c_tensors[1] }))
     }

-    pub fn f_internal_scaled_dot_product_cudnn_attention(
+    pub fn f_internal_scaled_dot_product_cudnn_attention_backward(
+        grad_out: &Tensor,
         query: &Tensor,
         key: &Tensor,
         value: &Tensor,
+        out: &Tensor,
+        logsumexp: &Tensor,
+        cum_seq_q: &Tensor,
+        cum_seq_k: &Tensor,
+        max_q: i64,
+        max_k: i64,
         dropout_p: f64,
         is_causal: bool,
-        return_debug_mask: bool,
+        philox_seed: &Tensor,
+        philox_offset: &Tensor,
         scale: impl Into<Option<f64>>,
-    ) -> Result<(Tensor, Tensor, Tensor, Tensor), TchError> {
+    ) -> Result<(Tensor, Tensor, Tensor), TchError> {
         let scale = scale.into();
-        let mut c_tensors = [std::ptr::null_mut(); 4];
-        unsafe_torch_err!(atg__scaled_dot_product_cudnn_attention(
+        let mut c_tensors = [std::ptr::null_mut(); 3];
+        unsafe_torch_err!(atg__scaled_dot_product_cudnn_attention_backward(
             c_tensors.as_mut_ptr(),
+            grad_out.c_tensor,
             query.c_tensor,
             key.c_tensor,
             value.c_tensor,
+            out.c_tensor,
+            logsumexp.c_tensor,
+            cum_seq_q.c_tensor,
+            cum_seq_k.c_tensor,
+            max_q,
+            max_k,
             dropout_p,
             if is_causal { 1 } else { 0 },
-            if return_debug_mask { 1 } else { 0 },
+            philox_seed.c_tensor,
+            philox_offset.c_tensor,
             scale.unwrap_or(std::f64::NAN),
             scale.is_none() as i8
         ));
@@ -5326,7 +5463,6 @@ impl Tensor {
             Tensor { c_tensor: c_tensors[0] },
             Tensor { c_tensor: c_tensors[1] },
             Tensor { c_tensor: c_tensors[2] },
-            Tensor { c_tensor: c_tensors[3] },
         ))
     }

@@ -5951,6 +6087,30 @@ impl Tensor {
         Ok(Tensor { c_tensor: c_tensors[0] })
     }

+    pub fn f_internal_sparse_compressed_tensor_with_dims(
+        nnz: i64,
+        dense_dim: i64,
+        size: impl IntList,
+        blocksize: impl IntList,
+        index_dtype: Kind,
+        options: (Kind, Device),
+    ) -> Result<Tensor, TchError> {
+        let mut c_tensors = [std::ptr::null_mut(); 1];
+        unsafe_torch_err!(atg__sparse_compressed_tensor_with_dims(
+            c_tensors.as_mut_ptr(),
+            nnz,
+            dense_dim,
+            size.as_ptr(),
+            size.len_i32(),
+            blocksize.as_ptr(),
+            blocksize.len_i32(),
+            index_dtype.c_int(),
+            options.0.c_int(),
+            options.1.c_int()
+        ));
+        Ok(Tensor { c_tensor: c_tensors[0] })
+    }
+
     pub fn f_internal_sparse_coo_tensor_unsafe(
         indices: &Tensor,
         values: &Tensor,
@@ -6329,6 +6489,51 @@ impl Tensor {
         Ok((Tensor { c_tensor: c_tensors[0] }, Tensor { c_tensor: c_tensors[1] }))
     }

+    pub fn f_internal_sparse_semi_structured_addmm(
+        &self,
+        mat1: &Tensor,
+        mat1_meta: &Tensor,
+        mat2: &Tensor,
+        out_dtype: impl Into<Option<Kind>>,
+    ) -> Result<Tensor, TchError> {
+        let mut c_tensors = [std::ptr::null_mut(); 1];
+        unsafe_torch_err!(atg__sparse_semi_structured_addmm(
+            c_tensors.as_mut_ptr(),
+            self.c_tensor,
+            mat1.c_tensor,
+            mat1_meta.c_tensor,
+            mat2.c_tensor,
+            out_dtype.into().map_or(-1, |s| s.c_int())
+        ));
+        Ok(Tensor { c_tensor: c_tensors[0] })
+    }
+
+    pub fn f_internal_sparse_semi_structured_apply(
+        &self,
+        thread_masks: &Tensor,
+    ) -> Result<(Tensor, Tensor), TchError> {
+        let mut c_tensors = [std::ptr::null_mut(); 2];
+        unsafe_torch_err!(atg__sparse_semi_structured_apply(
+            c_tensors.as_mut_ptr(),
+            self.c_tensor,
+            thread_masks.c_tensor
+        ));
+        Ok((Tensor { c_tensor: c_tensors[0] }, Tensor { c_tensor: c_tensors[1] }))
+    }
+
+    pub fn f_internal_sparse_semi_structured_apply_dense(
+        &self,
+        thread_masks: &Tensor,
+    ) -> Result<Tensor, TchError> {
+        let mut c_tensors = [std::ptr::null_mut(); 1];
+        unsafe_torch_err!(atg__sparse_semi_structured_apply_dense(
+            c_tensors.as_mut_ptr(),
+            self.c_tensor,
+            thread_masks.c_tensor
+        ));
+        Ok(Tensor { c_tensor: c_tensors[0] })
+    }
+
     pub fn f_internal_sparse_semi_structured_linear<T: Borrow<Tensor>>(
         &self,
         weight: &Tensor,
@@ -6351,6 +6556,45 @@ impl Tensor {
         Ok(Tensor { c_tensor: c_tensors[0] })
     }

+    pub fn f_internal_sparse_semi_structured_mm(
+        mat1: &Tensor,
+        mat1_meta: &Tensor,
+        mat2: &Tensor,
+        out_dtype: impl Into<Option<Kind>>,
+    ) -> Result<Tensor, TchError> {
+        let mut c_tensors = [std::ptr::null_mut(); 1];
+        unsafe_torch_err!(atg__sparse_semi_structured_mm(
+            c_tensors.as_mut_ptr(),
+            mat1.c_tensor,
+            mat1_meta.c_tensor,
+            mat2.c_tensor,
+            out_dtype.into().map_or(-1, |s| s.c_int())
+        ));
+        Ok(Tensor { c_tensor: c_tensors[0] })
+    }
+
+    pub fn f_internal_sparse_semi_structured_tile(
+        &self,
+        algorithm: &str,
+        use_cutlass: bool,
+    ) -> Result<(Tensor, Tensor, Tensor, Tensor, Tensor), TchError> {
+        let mut c_tensors = [std::ptr::null_mut(); 5];
+        unsafe_torch_err!(atg__sparse_semi_structured_tile(
+            c_tensors.as_mut_ptr(),
+            self.c_tensor,
+            algorithm.as_ptr(),
+            algorithm.len() as i32,
+            if use_cutlass { 1 } else { 0 }
+        ));
+        Ok((
+            Tensor { c_tensor: c_tensors[0] },
+            Tensor { c_tensor: c_tensors[1] },
+            Tensor { c_tensor: c_tensors[2] },
+            Tensor { c_tensor: c_tensors[3] },
+            Tensor { c_tensor: c_tensors[4] },
+        ))
+    }
+
     pub fn f_internal_sparse_softmax(
         &self,
         dim: i64,
@@ -7627,38 +7871,6 @@ impl Tensor {
         Ok((Tensor { c_tensor: c_tensors[0] }, Tensor { c_tensor: c_tensors[1] }))
     }

-    pub fn f_internal_unsafe_index<T: Borrow<Tensor>>(
-        &self,
-        indices: &[Option<T>],
-    ) -> Result<Tensor, TchError> {
-        let mut c_tensors = [std::ptr::null_mut(); 1];
-        unsafe_torch_err!(atg__unsafe_index(
-            c_tensors.as_mut_ptr(),
-            self.c_tensor,
-            ptr_list_opt(indices).as_ptr(),
-            indices.len() as i32
-        ));
-        Ok(Tensor { c_tensor: c_tensors[0] })
-    }
-
-    pub fn f_internal_unsafe_index_put<T: Borrow<Tensor>>(
-        &self,
-        indices: &[Option<T>],
-        values: &Tensor,
-        accumulate: bool,
-    ) -> Result<Tensor, TchError> {
-        let mut c_tensors = [std::ptr::null_mut(); 1];
-        unsafe_torch_err!(atg__unsafe_index_put(
-            c_tensors.as_mut_ptr(),
-            self.c_tensor,
-            ptr_list_opt(indices).as_ptr(),
-            indices.len() as i32,
-            values.c_tensor,
-            if accumulate { 1 } else { 0 }
-        ));
-        Ok(Tensor { c_tensor: c_tensors[0] })
-    }
-
     pub fn f_internal_unsafe_view(&self, size: impl IntList) -> Result<Tensor, TchError> {
         let mut c_tensors = [std::ptr::null_mut(); 1];
         unsafe_torch_err!(atg__unsafe_view(
@@ -11815,9 +12027,9 @@ impl Tensor {
         Ok(Tensor { c_tensor: c_tensors[0] })
     }

-    pub fn f_can_cast(from: Kind, to: Kind) -> Result<bool, TchError> {
+    pub fn f_can_cast(from_: Kind, to: Kind) -> Result<bool, TchError> {
         let return_;
-        unsafe_torch_err!(return_ = atg_can_cast(from.c_int(), to.c_int()));
+        unsafe_torch_err!(return_ = atg_can_cast(from_.c_int(), to.c_int()));
         Ok(return_ != 0)
     }

@@ -19141,17 +19353,6 @@ impl Tensor {
         Ok(Tensor { c_tensor: c_tensors[0] })
     }

-    pub fn f_index<T: Borrow<Tensor>>(&self, indices: &[Option<T>]) -> Result<Tensor, TchError> {
-        let mut c_tensors = [std::ptr::null_mut(); 1];
-        unsafe_torch_err!(atg_index(
-            c_tensors.as_mut_ptr(),
-            self.c_tensor,
-            ptr_list_opt(indices).as_ptr(),
-            indices.len() as i32
-        ));
-        Ok(Tensor { c_tensor: c_tensors[0] })
-    }
-
     pub fn f_index_add(
         &self,
         dim: i64,
@@ -19364,62 +19565,6 @@ impl Tensor {
         Ok(Tensor { c_tensor: c_tensors[0] })
     }

-    pub fn f_index_put<T: Borrow<Tensor>>(
-        &self,
-        indices: &[Option<T>],
-        values: &Tensor,
-        accumulate: bool,
-    ) -> Result<Tensor, TchError> {
-        let mut c_tensors = [std::ptr::null_mut(); 1];
-        unsafe_torch_err!(atg_index_put(
-            c_tensors.as_mut_ptr(),
-            self.c_tensor,
-            ptr_list_opt(indices).as_ptr(),
-            indices.len() as i32,
-            values.c_tensor,
-            if accumulate { 1 } else { 0 }
-        ));
-        Ok(Tensor { c_tensor: c_tensors[0] })
-    }
-
-    pub fn f_index_put_<T: Borrow<Tensor>>(
-        &mut self,
-        indices: &[Option<T>],
-        values: &Tensor,
-        accumulate: bool,
-    ) -> Result<Tensor, TchError> {
-        let mut c_tensors = [std::ptr::null_mut(); 1];
-        unsafe_torch_err!(atg_index_put_(
-            c_tensors.as_mut_ptr(),
-            self.c_tensor,
-            ptr_list_opt(indices).as_ptr(),
-            indices.len() as i32,
-            values.c_tensor,
-            if accumulate { 1 } else { 0 }
-        ));
-        Ok(Tensor { c_tensor: c_tensors[0] })
-    }
-
-    pub fn f_index_put_out<T: Borrow<Tensor>>(
-        &self,
-        out: &Tensor,
-        indices: &[Option<T>],
-        values: &Tensor,
-        accumulate: bool,
-    ) -> Result<Tensor, TchError> {
-        let mut c_tensors = [std::ptr::null_mut(); 1];
-        unsafe_torch_err!(atg_index_put_out(
-            c_tensors.as_mut_ptr(),
-            out.c_tensor,
-            self.c_tensor,
-            ptr_list_opt(indices).as_ptr(),
-            indices.len() as i32,
-            values.c_tensor,
-            if accumulate { 1 } else { 0 }
-        ));
-        Ok(Tensor { c_tensor: c_tensors[0] })
-    }
-
     pub fn f_index_reduce(
         &self,
         dim: i64,
@@ -19534,22 +19679,6 @@ impl Tensor {
         Ok(Tensor { c_tensor: c_tensors[0] })
     }

-    pub fn f_index_tensor_out<T: Borrow<Tensor>>(
-        &self,
-        out: &Tensor,
-        indices: &[Option<T>],
-    ) -> Result<Tensor, TchError> {
-        let mut c_tensors = [std::ptr::null_mut(); 1];
-        unsafe_torch_err!(atg_index_tensor_out(
-            c_tensors.as_mut_ptr(),
-            out.c_tensor,
-            self.c_tensor,
-            ptr_list_opt(indices).as_ptr(),
-            indices.len() as i32
-        ));
-        Ok(Tensor { c_tensor: c_tensors[0] })
-    }
-
     pub fn f_indices(&self) -> Result<Tensor, TchError> {
         let mut c_tensors = [std::ptr::null_mut(); 1];
         unsafe_torch_err!(atg_indices(c_tensors.as_mut_ptr(), self.c_tensor));
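[review note] `f_can_cast`/`can_cast` only rename the first parameter (`from` to `from_`); calls are positional, so existing call sites keep compiling. For example:

```rust
use tch::{Kind, Tensor};

fn main() {
    // Widening int64 -> float is allowed, the reverse is not.
    assert!(Tensor::can_cast(Kind::Int64, Kind::Float));
    assert!(!Tensor::can_cast(Kind::Float, Kind::Int64));
}
```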
@@ -24966,6 +25095,7 @@ impl Tensor {
         stride: impl IntList,
         dilation: impl IntList,
         groups: i64,
+        input_size: impl IntListOption,
     ) -> Result<Tensor, TchError> {
         let mut c_tensors = [std::ptr::null_mut(); 1];
         unsafe_torch_err!(atg_mkldnn_reorder_conv3d_weight(
@@ -24977,7 +25107,9 @@ impl Tensor {
             stride.len_i32(),
             dilation.as_ptr(),
             dilation.len_i32(),
-            groups
+            groups,
+            input_size.as_ptr(),
+            input_size.len_i32()
         ));
         Ok(Tensor { c_tensor: c_tensors[0] })
     }
@@ -24989,6 +25121,7 @@ impl Tensor {
         stride: impl IntList,
         dilation: impl IntList,
         groups: i64,
+        input_size: impl IntListOption,
     ) -> Result<Tensor, TchError> {
         let mut c_tensors = [std::ptr::null_mut(); 1];
         unsafe_torch_err!(atg_mkldnn_reorder_conv3d_weight_out(
@@ -25001,7 +25134,9 @@ impl Tensor {
             stride.len_i32(),
             dilation.as_ptr(),
             dilation.len_i32(),
-            groups
+            groups,
+            input_size.as_ptr(),
+            input_size.len_i32()
         ));
         Ok(Tensor { c_tensor: c_tensors[0] })
     }
@@ -29628,6 +29763,26 @@ impl Tensor {
         Ok(return_ != 0)
     }

+    pub fn f_rms_norm<T: Borrow<Tensor>>(
+        &self,
+        normalized_shape: impl IntList,
+        weight: Option<T>,
+        eps: impl Into<Option<f64>>,
+    ) -> Result<Tensor, TchError> {
+        let eps = eps.into();
+        let mut c_tensors = [std::ptr::null_mut(); 1];
+        unsafe_torch_err!(atg_rms_norm(
+            c_tensors.as_mut_ptr(),
+            self.c_tensor,
+            normalized_shape.as_ptr(),
+            normalized_shape.len_i32(),
+            weight.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor),
+            eps.unwrap_or(std::f64::NAN),
+            eps.is_none() as i8
+        ));
+        Ok(Tensor { c_tensor: c_tensors[0] })
+    }
+
     pub fn f_rnn_relu<T: Borrow<Tensor>>(
         &self,
         hx: &Tensor,
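[review note] A hedged usage sketch for the newly generated `f_rms_norm` binding (assumes a libtorch build recent enough to export `torch::rms_norm`); the shapes are arbitrary.

```rust
use tch::{Device, Kind, TchError, Tensor};

fn rms_norm_sketch() -> Result<(), TchError> {
    let opts = (Kind::Float, Device::Cpu);
    let hidden: &[i64] = &[512];
    let x = Tensor::randn(&[2i64, 8, 512][..], opts);
    let weight = Tensor::ones(hidden, opts);
    // Normalize over the trailing dimension; `eps` is `impl Into<Option<f64>>`,
    // so both an explicit value and `None` are accepted.
    let y = x.f_rms_norm(hidden, Some(&weight), 1e-5)?;
    assert_eq!(y.size(), x.size());
    Ok(())
}
```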
diff --git a/src/wrappers/tensor_generated.rs b/src/wrappers/tensor_generated.rs
index 56a2217f..7fcef4dc 100644
--- a/src/wrappers/tensor_generated.rs
+++ b/src/wrappers/tensor_generated.rs
@@ -312,6 +312,115 @@ impl Tensor {
         .unwrap()
     }

+    pub fn internal_batch_norm_no_update<T: Borrow<Tensor>>(
+        &self,
+        weight: Option<T>,
+        bias: Option<T>,
+        running_mean: Option<T>,
+        running_var: Option<T>,
+        momentum: f64,
+        eps: f64,
+    ) -> (Tensor, Tensor, Tensor, Tensor) {
+        self.f_internal_batch_norm_no_update(weight, bias, running_mean, running_var, momentum, eps)
+            .unwrap()
+    }
+
+    pub fn internal_batch_norm_no_update_out<T: Borrow<Tensor>>(
+        &self,
+        out0: &Tensor,
+        out1: &Tensor,
+        out2: &Tensor,
+        out3: &Tensor,
+        weight: Option<T>,
+        bias: Option<T>,
+        running_mean: Option<T>,
+        running_var: Option<T>,
+        momentum: f64,
+        eps: f64,
+    ) -> (Tensor, Tensor, Tensor, Tensor) {
+        self.f_internal_batch_norm_no_update_out(
+            out0,
+            out1,
+            out2,
+            out3,
+            weight,
+            bias,
+            running_mean,
+            running_var,
+            momentum,
+            eps,
+        )
+        .unwrap()
+    }
+
+    pub fn internal_batch_norm_with_update<T: Borrow<Tensor>>(
+        &self,
+        weight: Option<T>,
+        bias: Option<T>,
+        running_mean: &Tensor,
+        running_var: &Tensor,
+        momentum: f64,
+        eps: f64,
+    ) -> (Tensor, Tensor, Tensor, Tensor) {
+        self.f_internal_batch_norm_with_update(
+            weight,
+            bias,
+            running_mean,
+            running_var,
+            momentum,
+            eps,
+        )
+        .unwrap()
+    }
+
+    pub fn internal_batch_norm_with_update_functional<T: Borrow<Tensor>>(
+        &self,
+        weight: Option<T>,
+        bias: Option<T>,
+        running_mean: &Tensor,
+        running_var: &Tensor,
+        momentum: f64,
+        eps: f64,
+    ) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) {
+        self.f_internal_batch_norm_with_update_functional(
+            weight,
+            bias,
+            running_mean,
+            running_var,
+            momentum,
+            eps,
+        )
+        .unwrap()
+    }
+
+    pub fn internal_batch_norm_with_update_out<T: Borrow<Tensor>>(
+        &self,
+        out: &Tensor,
+        save_mean: &Tensor,
+        save_invstd: &Tensor,
+        reserve: &Tensor,
+        weight: Option<T>,
+        bias: Option<T>,
+        running_mean: &Tensor,
+        running_var: &Tensor,
+        momentum: f64,
+        eps: f64,
+    ) -> (Tensor, Tensor, Tensor, Tensor) {
+        self.f_internal_batch_norm_with_update_out(
+            out,
+            save_mean,
+            save_invstd,
+            reserve,
+            weight,
+            bias,
+            running_mean,
+            running_var,
+            momentum,
+            eps,
+        )
+        .unwrap()
+    }
+
     pub fn internal_cast_byte(&self, non_blocking: bool) -> Tensor {
         self.f_internal_cast_byte(non_blocking).unwrap()
     }
@@ -1131,6 +1240,8 @@ impl Tensor {
         bias_requires_grad: bool,
         scale: impl Into<Option<f64>>,
         num_splits_key: impl Into<Option<i64>>,
+        window_size: impl Into<Option<i64>>,
+        shared_storage_dqdkdv: bool,
     ) -> (Tensor, Tensor, Tensor, Tensor) {
         Tensor::f_internal_efficient_attention_backward(
             grad_out_,
@@ -1151,6 +1262,8 @@ impl Tensor {
             bias_requires_grad,
             scale,
             num_splits_key,
+            window_size,
+            shared_storage_dqdkdv,
         )
         .unwrap()
     }
@@ -1736,6 +1849,8 @@ impl Tensor {
         philox_seed: &Tensor,
         philox_offset: &Tensor,
         scale: impl Into<Option<f64>>,
+        window_size_left: impl Into<Option<i64>>,
+        window_size_right: impl Into<Option<i64>>,
     ) -> (Tensor, Tensor, Tensor) {
         Tensor::f_internal_flash_attention_backward(
             grad_out,
@@ -1753,6 +1868,8 @@ impl Tensor {
             philox_seed,
             philox_offset,
             scale,
+            window_size_left,
+            window_size_right,
         )
         .unwrap()
     }
@@ -2066,37 +2183,6 @@ impl Tensor {
         self.f_internal_histogramdd_from_bin_tensors_out(out, bins, weight, density).unwrap()
     }

-    pub fn internal_index_put_impl<T: Borrow<Tensor>>(
-        &self,
-        indices: &[Option<T>],
-        values: &Tensor,
-        accumulate: bool,
-        unsafe_: bool,
-    ) -> Tensor {
-        self.f_internal_index_put_impl(indices, values, accumulate, unsafe_).unwrap()
-    }
-
-    pub fn internal_index_put_impl_<T: Borrow<Tensor>>(
-        &mut self,
-        indices: &[Option<T>],
-        values: &Tensor,
-        accumulate: bool,
-        unsafe_: bool,
-    ) -> Tensor {
-        self.f_internal_index_put_impl_(indices, values, accumulate, unsafe_).unwrap()
-    }
-
-    pub fn internal_index_put_impl_out<T: Borrow<Tensor>>(
-        &self,
-        out: &Tensor,
-        indices: &[Option<T>],
-        values: &Tensor,
-        accumulate: bool,
-        unsafe_: bool,
-    ) -> Tensor {
-        self.f_internal_index_put_impl_out(out, indices, values, accumulate, unsafe_).unwrap()
-    }
-
     pub fn internal_indices(&self) -> Tensor {
         self.f_internal_indices().unwrap()
     }
@@ -2773,6 +2859,12 @@ impl Tensor {
         self.f_internal_neg_view_copy_out(out).unwrap()
     }

+    pub fn internal_nested_compute_contiguous_strides_offsets(
+        nested_size: &Tensor,
+    ) -> (Tensor, Tensor) {
+        Tensor::f_internal_nested_compute_contiguous_strides_offsets(nested_size).unwrap()
+    }
+
     pub fn internal_nested_from_padded(
         padded: &Tensor,
         cpu_nested_shape_example: &Tensor,
@@ -3156,22 +3248,38 @@ impl Tensor {
         .unwrap()
     }

-    pub fn internal_scaled_dot_product_cudnn_attention(
+    pub fn internal_scaled_dot_product_cudnn_attention_backward(
+        grad_out: &Tensor,
         query: &Tensor,
         key: &Tensor,
         value: &Tensor,
+        out: &Tensor,
+        logsumexp: &Tensor,
+        cum_seq_q: &Tensor,
+        cum_seq_k: &Tensor,
+        max_q: i64,
+        max_k: i64,
         dropout_p: f64,
         is_causal: bool,
-        return_debug_mask: bool,
+        philox_seed: &Tensor,
+        philox_offset: &Tensor,
         scale: impl Into<Option<f64>>,
-    ) -> (Tensor, Tensor, Tensor, Tensor) {
-        Tensor::f_internal_scaled_dot_product_cudnn_attention(
+    ) -> (Tensor, Tensor, Tensor) {
+        Tensor::f_internal_scaled_dot_product_cudnn_attention_backward(
+            grad_out,
             query,
             key,
             value,
+            out,
+            logsumexp,
+            cum_seq_q,
+            cum_seq_k,
+            max_q,
+            max_k,
             dropout_p,
             is_causal,
-            return_debug_mask,
+            philox_seed,
+            philox_offset,
             scale,
         )
         .unwrap()
@@ -3552,6 +3660,25 @@ impl Tensor {
         .unwrap()
     }

+    pub fn internal_sparse_compressed_tensor_with_dims(
+        nnz: i64,
+        dense_dim: i64,
+        size: impl IntList,
+        blocksize: impl IntList,
+        index_dtype: Kind,
+        options: (Kind, Device),
+    ) -> Tensor {
+        Tensor::f_internal_sparse_compressed_tensor_with_dims(
+            nnz,
+            dense_dim,
+            size,
+            blocksize,
+            index_dtype,
+            options,
+        )
+        .unwrap()
+    }
+
     pub fn internal_sparse_coo_tensor_unsafe(
         indices: &Tensor,
         values: &Tensor,
@@ -3766,6 +3893,24 @@ impl Tensor {
         self.f_internal_sparse_mm_reduce_impl(other, reduce).unwrap()
     }

+    pub fn internal_sparse_semi_structured_addmm(
+        &self,
+        mat1: &Tensor,
+        mat1_meta: &Tensor,
+        mat2: &Tensor,
+        out_dtype: impl Into<Option<Kind>>,
+    ) -> Tensor {
+        self.f_internal_sparse_semi_structured_addmm(mat1, mat1_meta, mat2, out_dtype).unwrap()
+    }
+
+    pub fn internal_sparse_semi_structured_apply(&self, thread_masks: &Tensor) -> (Tensor, Tensor) {
+        self.f_internal_sparse_semi_structured_apply(thread_masks).unwrap()
+    }
+
+    pub fn internal_sparse_semi_structured_apply_dense(&self, thread_masks: &Tensor) -> Tensor {
+        self.f_internal_sparse_semi_structured_apply_dense(thread_masks).unwrap()
+    }
+
     pub fn internal_sparse_semi_structured_linear<T: Borrow<Tensor>>(
         &self,
         weight: &Tensor,
@@ -3778,6 +3923,23 @@ impl Tensor {
         .unwrap()
     }

+    pub fn internal_sparse_semi_structured_mm(
+        mat1: &Tensor,
+        mat1_meta: &Tensor,
+        mat2: &Tensor,
+        out_dtype: impl Into<Option<Kind>>,
+    ) -> Tensor {
+        Tensor::f_internal_sparse_semi_structured_mm(mat1, mat1_meta, mat2, out_dtype).unwrap()
+    }
+
+    pub fn internal_sparse_semi_structured_tile(
+        &self,
+        algorithm: &str,
+        use_cutlass: bool,
+    ) -> (Tensor, Tensor, Tensor, Tensor, Tensor) {
+        self.f_internal_sparse_semi_structured_tile(algorithm, use_cutlass).unwrap()
+    }
+
     pub fn internal_sparse_softmax(&self, dim: i64, half_to_float: bool) -> Tensor {
         self.f_internal_sparse_softmax(dim, half_to_float).unwrap()
     }
@@ -4386,19 +4548,6 @@ impl Tensor {
         Tensor::f_internal_unpack_dual(dual, level).unwrap()
     }

-    pub fn internal_unsafe_index<T: Borrow<Tensor>>(&self, indices: &[Option<T>]) -> Tensor {
-        self.f_internal_unsafe_index(indices).unwrap()
-    }
-
-    pub fn internal_unsafe_index_put<T: Borrow<Tensor>>(
-        &self,
-        indices: &[Option<T>],
-        values: &Tensor,
-        accumulate: bool,
-    ) -> Tensor {
-        self.f_internal_unsafe_index_put(indices, values, accumulate).unwrap()
-    }
-
     pub fn internal_unsafe_view(&self, size: impl IntList) -> Tensor {
         self.f_internal_unsafe_view(size).unwrap()
     }
@@ -6498,8 +6647,8 @@ impl Tensor {
         self.f_bucketize_tensor_out(out, boundaries, out_int32, right).unwrap()
     }

-    pub fn can_cast(from: Kind, to: Kind) -> bool {
-        Tensor::f_can_cast(from, to).unwrap()
+    pub fn can_cast(from_: Kind, to: Kind) -> bool {
+        Tensor::f_can_cast(from_, to).unwrap()
     }

     pub fn cartesian_prod<T: Borrow<Tensor>>(tensors: &[T]) -> Tensor {
@@ -9955,10 +10104,6 @@ impl Tensor {
         self.f_imag().unwrap()
     }

-    pub fn index<T: Borrow<Tensor>>(&self, indices: &[Option<T>]) -> Tensor {
-        self.f_index(indices).unwrap()
-    }
-
     pub fn index_add(&self, dim: i64, index: &Tensor, source: &Tensor) -> Tensor {
         self.f_index_add(dim, index, source).unwrap()
     }
@@ -10025,34 +10170,6 @@ impl Tensor {
         self.f_index_fill_int_tensor_out(out, dim, index, value).unwrap()
     }

-    pub fn index_put<T: Borrow<Tensor>>(
-        &self,
-        indices: &[Option<T>],
-        values: &Tensor,
-        accumulate: bool,
-    ) -> Tensor {
-        self.f_index_put(indices, values, accumulate).unwrap()
-    }
-
-    pub fn index_put_<T: Borrow<Tensor>>(
-        &mut self,
-        indices: &[Option<T>],
-        values: &Tensor,
-        accumulate: bool,
-    ) -> Tensor {
-        self.f_index_put_(indices, values, accumulate).unwrap()
-    }
-
-    pub fn index_put_out<T: Borrow<Tensor>>(
-        &self,
-        out: &Tensor,
-        indices: &[Option<T>],
-        values: &Tensor,
-        accumulate: bool,
-    ) -> Tensor {
-        self.f_index_put_out(out, indices, values, accumulate).unwrap()
-    }
-
     pub fn index_reduce(
         &self,
         dim: i64,
@@ -10104,14 +10221,6 @@ impl Tensor {
         self.f_index_select_out(out, dim, index).unwrap()
     }

-    pub fn index_tensor_out<T: Borrow<Tensor>>(
-        &self,
-        out: &Tensor,
-        indices: &[Option<T>],
-    ) -> Tensor {
-        self.f_index_tensor_out(out, indices).unwrap()
-    }
-
     pub fn indices(&self) -> Tensor {
         self.f_indices().unwrap()
     }
@@ -12868,8 +12977,9 @@ impl Tensor {
         stride: impl IntList,
         dilation: impl IntList,
         groups: i64,
+        input_size: impl IntListOption,
     ) -> Tensor {
-        self.f_mkldnn_reorder_conv3d_weight(padding, stride, dilation, groups).unwrap()
+        self.f_mkldnn_reorder_conv3d_weight(padding, stride, dilation, groups, input_size).unwrap()
     }

     pub fn mkldnn_reorder_conv3d_weight_out(
@@ -12879,8 +12989,10 @@ impl Tensor {
         stride: impl IntList,
         dilation: impl IntList,
         groups: i64,
+        input_size: impl IntListOption,
     ) -> Tensor {
-        self.f_mkldnn_reorder_conv3d_weight_out(out, padding, stride, dilation, groups).unwrap()
+        self.f_mkldnn_reorder_conv3d_weight_out(out, padding, stride, dilation, groups, input_size)
+            .unwrap()
     }

     pub fn mkldnn_rnn_layer(
@@ -15109,6 +15221,15 @@ impl Tensor {
         self.f_retains_grad().unwrap()
     }

+    pub fn rms_norm<T: Borrow<Tensor>>(
+        &self,
+        normalized_shape: impl IntList,
+        weight: Option<T>,
+        eps: impl Into<Option<f64>>,
+    ) -> Tensor {
+        self.f_rms_norm(normalized_shape, weight, eps).unwrap()
+    }
+
     pub fn rnn_relu<T: Borrow<Tensor>>(
         &self,
         hx: &Tensor,
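[review note] The `_sparse_compressed_tensor_with_dims` constructor is exposed in both panicking and fallible form. The sketch below only illustrates the argument order of the generated signature; whether construction succeeds depends on how the underlying libtorch call resolves layout and options, so the fallible wrapper is used and a failure is merely reported.

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    let size: &[i64] = &[4, 4];
    let blocksize: &[i64] = &[2, 2];
    match Tensor::f_internal_sparse_compressed_tensor_with_dims(
        0, // nnz
        0, // dense_dim
        size,
        blocksize,
        Kind::Int64, // index_dtype
        (Kind::Float, Device::Cpu),
    ) {
        Ok(t) => println!("constructed, size = {:?}", t.size()),
        Err(err) => println!("unsupported in this configuration: {err}"),
    }
}
```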
diff --git a/torch-sys/libtch/torch_api_generated.cpp b/torch-sys/libtch/torch_api_generated.cpp
index b54d6b0c..5c71f116 100644
--- a/torch-sys/libtch/torch_api_generated.cpp
+++ b/torch-sys/libtch/torch_api_generated.cpp
@@ -368,6 +368,58 @@ void atg__autocast_to_reduced_precision(tensor *out__, tensor self, int cuda_ena
   )
 }

+void atg__batch_norm_no_update(tensor *out__, tensor input, tensor weight, tensor bias, tensor running_mean, tensor running_var, double momentum, double eps) {
+  PROTECT(
+    auto outputs__ = torch::_batch_norm_no_update(*input, (weight ? *weight : torch::Tensor()), (bias ? *bias : torch::Tensor()), (running_mean ? *running_mean : torch::Tensor()), (running_var ? *running_var : torch::Tensor()), momentum, eps);
+    out__[0] = new torch::Tensor(std::get<0>(outputs__));
+    out__[1] = new torch::Tensor(std::get<1>(outputs__));
+    out__[2] = new torch::Tensor(std::get<2>(outputs__));
+    out__[3] = new torch::Tensor(std::get<3>(outputs__));
+  )
+}
+
+void atg__batch_norm_no_update_out(tensor *out__, tensor out0, tensor out1, tensor out2, tensor out3, tensor input, tensor weight, tensor bias, tensor running_mean, tensor running_var, double momentum, double eps) {
+  PROTECT(
+    auto outputs__ = torch::_batch_norm_no_update_out(*out0, *out1, *out2, *out3, *input, (weight ? *weight : torch::Tensor()), (bias ? *bias : torch::Tensor()), (running_mean ? *running_mean : torch::Tensor()), (running_var ? *running_var : torch::Tensor()), momentum, eps);
+    out__[0] = new torch::Tensor(std::get<0>(outputs__));
+    out__[1] = new torch::Tensor(std::get<1>(outputs__));
+    out__[2] = new torch::Tensor(std::get<2>(outputs__));
+    out__[3] = new torch::Tensor(std::get<3>(outputs__));
+  )
+}
+
+void atg__batch_norm_with_update(tensor *out__, tensor input, tensor weight, tensor bias, tensor running_mean, tensor running_var, double momentum, double eps) {
+  PROTECT(
+    auto outputs__ = torch::_batch_norm_with_update(*input, (weight ? *weight : torch::Tensor()), (bias ? *bias : torch::Tensor()), *running_mean, *running_var, momentum, eps);
+    out__[0] = new torch::Tensor(std::get<0>(outputs__));
+    out__[1] = new torch::Tensor(std::get<1>(outputs__));
+    out__[2] = new torch::Tensor(std::get<2>(outputs__));
+    out__[3] = new torch::Tensor(std::get<3>(outputs__));
+  )
+}
+
+void atg__batch_norm_with_update_functional(tensor *out__, tensor input, tensor weight, tensor bias, tensor running_mean, tensor running_var, double momentum, double eps) {
+  PROTECT(
+    auto outputs__ = torch::_batch_norm_with_update_functional(*input, (weight ? *weight : torch::Tensor()), (bias ? *bias : torch::Tensor()), *running_mean, *running_var, momentum, eps);
+    out__[0] = new torch::Tensor(std::get<0>(outputs__));
+    out__[1] = new torch::Tensor(std::get<1>(outputs__));
+    out__[2] = new torch::Tensor(std::get<2>(outputs__));
+    out__[3] = new torch::Tensor(std::get<3>(outputs__));
+    out__[4] = new torch::Tensor(std::get<4>(outputs__));
+    out__[5] = new torch::Tensor(std::get<5>(outputs__));
+  )
+}
+
+void atg__batch_norm_with_update_out(tensor *out__, tensor out, tensor save_mean, tensor save_invstd, tensor reserve, tensor input, tensor weight, tensor bias, tensor running_mean, tensor running_var, double momentum, double eps) {
+  PROTECT(
+    auto outputs__ = torch::_batch_norm_with_update_out(*out, *save_mean, *save_invstd, *reserve, *input, (weight ? *weight : torch::Tensor()), (bias ? *bias : torch::Tensor()), *running_mean, *running_var, momentum, eps);
+    out__[0] = new torch::Tensor(std::get<0>(outputs__));
+    out__[1] = new torch::Tensor(std::get<1>(outputs__));
+    out__[2] = new torch::Tensor(std::get<2>(outputs__));
+    out__[3] = new torch::Tensor(std::get<3>(outputs__));
+  )
+}
+
 void atg__cast_byte(tensor *out__, tensor self, int non_blocking) {
   PROTECT(
     auto outputs__ = torch::_cast_Byte(*self, (bool)non_blocking);
@@ -845,9 +897,9 @@ void atg__dirichlet_grad_out(tensor *out__, tensor out, tensor x, tensor alpha,
   )
 }

-void atg__efficient_attention_backward(tensor *out__, tensor grad_out_, tensor query, tensor key, tensor value, tensor bias, tensor out, tensor cu_seqlens_q, tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, tensor logsumexp, double dropout_p, tensor philox_seed, tensor philox_offset, int64_t custom_mask_type, int bias_requires_grad, double scale_v, uint8_t scale_null, int64_t num_splits_key_v, uint8_t num_splits_key_null) {
+void atg__efficient_attention_backward(tensor *out__, tensor grad_out_, tensor query, tensor key, tensor value, tensor bias, tensor out, tensor cu_seqlens_q, tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, tensor logsumexp, double dropout_p, tensor philox_seed, tensor philox_offset, int64_t custom_mask_type, int bias_requires_grad, double scale_v, uint8_t scale_null, int64_t num_splits_key_v, uint8_t num_splits_key_null, int64_t window_size_v, uint8_t window_size_null, int shared_storage_dqdkdv) {
   PROTECT(
-    auto outputs__ = torch::_efficient_attention_backward(*grad_out_, *query, *key, *value, (bias ? *bias : torch::Tensor()), *out, (cu_seqlens_q ? *cu_seqlens_q : torch::Tensor()), (cu_seqlens_k ? *cu_seqlens_k : torch::Tensor()), max_seqlen_q, max_seqlen_k, *logsumexp, dropout_p, *philox_seed, *philox_offset, custom_mask_type, (bool)bias_requires_grad, scale_null ? c10::nullopt : c10::optional<double>(scale_v), num_splits_key_null ? c10::nullopt : c10::optional<int64_t>(num_splits_key_v));
+    auto outputs__ = torch::_efficient_attention_backward(*grad_out_, *query, *key, *value, (bias ? *bias : torch::Tensor()), *out, (cu_seqlens_q ? *cu_seqlens_q : torch::Tensor()), (cu_seqlens_k ? *cu_seqlens_k : torch::Tensor()), max_seqlen_q, max_seqlen_k, *logsumexp, dropout_p, *philox_seed, *philox_offset, custom_mask_type, (bool)bias_requires_grad, scale_null ? c10::nullopt : c10::optional<double>(scale_v), num_splits_key_null ? c10::nullopt : c10::optional<int64_t>(num_splits_key_v), window_size_null ? c10::nullopt : c10::optional<int64_t>(window_size_v), (bool)shared_storage_dqdkdv);
     out__[0] = new torch::Tensor(std::get<0>(outputs__));
     out__[1] = new torch::Tensor(std::get<1>(outputs__));
     out__[2] = new torch::Tensor(std::get<2>(outputs__));
@@ -1104,9 +1156,9 @@ void atg__fill_mem_eff_dropout_mask_(tensor *out__, tensor self, double dropout_
   )
 }

-void atg__flash_attention_backward(tensor *out__, tensor grad_out, tensor query, tensor key, tensor value, tensor out, tensor logsumexp, tensor cum_seq_q, tensor cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int is_causal, tensor philox_seed, tensor philox_offset, double scale_v, uint8_t scale_null) {
+void atg__flash_attention_backward(tensor *out__, tensor grad_out, tensor query, tensor key, tensor value, tensor out, tensor logsumexp, tensor cum_seq_q, tensor cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int is_causal, tensor philox_seed, tensor philox_offset, double scale_v, uint8_t scale_null, int64_t window_size_left_v, uint8_t window_size_left_null, int64_t window_size_right_v, uint8_t window_size_right_null) {
   PROTECT(
-    auto outputs__ = torch::_flash_attention_backward(*grad_out, *query, *key, *value, *out, *logsumexp, *cum_seq_q, *cum_seq_k, max_q, max_k, dropout_p, (bool)is_causal, *philox_seed, *philox_offset, scale_null ? c10::nullopt : c10::optional<double>(scale_v));
+    auto outputs__ = torch::_flash_attention_backward(*grad_out, *query, *key, *value, *out, *logsumexp, *cum_seq_q, *cum_seq_k, max_q, max_k, dropout_p, (bool)is_causal, *philox_seed, *philox_offset, scale_null ? c10::nullopt : c10::optional<double>(scale_v), window_size_left_null ? c10::nullopt : c10::optional<int64_t>(window_size_left_v), window_size_right_null ? c10::nullopt : c10::optional<int64_t>(window_size_right_v));
     out__[0] = new torch::Tensor(std::get<0>(outputs__));
     out__[1] = new torch::Tensor(std::get<1>(outputs__));
     out__[2] = new torch::Tensor(std::get<2>(outputs__));
@@ -1317,27 +1369,6 @@ void atg__histogramdd_from_bin_tensors_out(tensor *out__, tensor out, tensor sel
   )
 }

-void atg__index_put_impl(tensor *out__, tensor self, tensor *indices_data, int indices_len, tensor values, int accumulate, int unsafe) {
-  PROTECT(
-    auto outputs__ = torch::_index_put_impl(*self, of_carray_tensor_opt(indices_data, indices_len), *values, (bool)accumulate, (bool)unsafe);
-    out__[0] = new torch::Tensor(outputs__);
-  )
-}
-
-void atg__index_put_impl_(tensor *out__, tensor self, tensor *indices_data, int indices_len, tensor values, int accumulate, int unsafe) {
-  PROTECT(
-    auto outputs__ = torch::_index_put_impl_(*self, of_carray_tensor_opt(indices_data, indices_len), *values, (bool)accumulate, (bool)unsafe);
-    out__[0] = new torch::Tensor(outputs__);
-  )
-}
-
-void atg__index_put_impl_out(tensor *out__, tensor out, tensor self, tensor *indices_data, int indices_len, tensor values, int accumulate, int unsafe) {
-  PROTECT(
-    auto outputs__ = torch::_index_put_impl_out(*out, *self, of_carray_tensor_opt(indices_data, indices_len), *values, (bool)accumulate, (bool)unsafe);
-    out__[0] = new torch::Tensor(outputs__);
-  )
-}
-
 void atg__indices(tensor *out__, tensor self) {
   PROTECT(
     auto outputs__ = self->_indices();
@@ -1851,6 +1882,14 @@ void atg__neg_view_copy_out(tensor *out__, tensor out, tensor self) {
   )
 }

+void atg__nested_compute_contiguous_strides_offsets(tensor *out__, tensor nested_size) {
+  PROTECT(
+    auto outputs__ = torch::_nested_compute_contiguous_strides_offsets(*nested_size);
+    out__[0] = new torch::Tensor(std::get<0>(outputs__));
+    out__[1] = new torch::Tensor(std::get<1>(outputs__));
+  )
+}
+
 void atg__nested_from_padded(tensor *out__, tensor padded, tensor cpu_nested_shape_example, int fuse_transform_0213) {
   PROTECT(
     auto outputs__ = torch::_nested_from_padded(*padded, *cpu_nested_shape_example, (bool)fuse_transform_0213);
@@ -2226,13 +2265,12 @@ void atg__scaled_dot_product_attention_math(tensor *out__, tensor query, tensor
   )
 }

-void atg__scaled_dot_product_cudnn_attention(tensor *out__, tensor query, tensor key, tensor value, double dropout_p, int is_causal, int return_debug_mask, double scale_v, uint8_t scale_null) {
+void atg__scaled_dot_product_cudnn_attention_backward(tensor *out__, tensor grad_out, tensor query, tensor key, tensor value, tensor out, tensor logsumexp, tensor cum_seq_q, tensor cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int is_causal, tensor philox_seed, tensor philox_offset, double scale_v, uint8_t scale_null) {
   PROTECT(
-    auto outputs__ = torch::_scaled_dot_product_cudnn_attention(*query, *key, *value, dropout_p, (bool)is_causal, (bool)return_debug_mask, scale_null ? c10::nullopt : c10::optional<double>(scale_v));
+    auto outputs__ = torch::_scaled_dot_product_cudnn_attention_backward(*grad_out, *query, *key, *value, *out, *logsumexp, *cum_seq_q, *cum_seq_k, max_q, max_k, dropout_p, (bool)is_causal, *philox_seed, *philox_offset, scale_null ?
c10::nullopt : c10::optional(scale_v)); out__[0] = new torch::Tensor(std::get<0>(outputs__)); out__[1] = new torch::Tensor(std::get<1>(outputs__)); out__[2] = new torch::Tensor(std::get<2>(outputs__)); - out__[3] = new torch::Tensor(std::get<3>(outputs__)); ) } @@ -2452,6 +2490,13 @@ void atg__sparse_compressed_tensor_unsafe(tensor *out__, tensor compressed_indic ) } +void atg__sparse_compressed_tensor_with_dims(tensor *out__, int64_t nnz, int64_t dense_dim, int64_t *size_data, int size_len, int64_t *blocksize_data, int blocksize_len, int index_dtype, int options_kind, int options_device) { + PROTECT( + auto outputs__ = torch::_sparse_compressed_tensor_with_dims(nnz, dense_dim, torch::IntArrayRef(size_data, size_len), torch::IntArrayRef(blocksize_data, blocksize_len), at::ScalarType(index_dtype), at::device(device_of_int(options_device)).dtype(at::ScalarType(options_kind))); + out__[0] = new torch::Tensor(outputs__); + ) +} + void atg__sparse_coo_tensor_unsafe(tensor *out__, tensor indices, tensor values, int64_t *size_data, int size_len, int options_kind, int options_device, int is_coalesced) { PROTECT( auto outputs__ = torch::_sparse_coo_tensor_unsafe(*indices, *values, torch::IntArrayRef(size_data, size_len), at::device(device_of_int(options_device)).dtype(at::ScalarType(options_kind)), (bool)is_coalesced); @@ -2600,6 +2645,28 @@ void atg__sparse_mm_reduce_impl(tensor *out__, tensor self, tensor other, char* ) } +void atg__sparse_semi_structured_addmm(tensor *out__, tensor input, tensor mat1, tensor mat1_meta, tensor mat2, int out_dtype) { + PROTECT( + auto outputs__ = torch::_sparse_semi_structured_addmm(*input, *mat1, *mat1_meta, *mat2, out_dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(out_dtype))); + out__[0] = new torch::Tensor(outputs__); + ) +} + +void atg__sparse_semi_structured_apply(tensor *out__, tensor input, tensor thread_masks) { + PROTECT( + auto outputs__ = torch::_sparse_semi_structured_apply(*input, *thread_masks); + out__[0] = new torch::Tensor(std::get<0>(outputs__)); + out__[1] = new torch::Tensor(std::get<1>(outputs__)); + ) +} + +void atg__sparse_semi_structured_apply_dense(tensor *out__, tensor input, tensor thread_masks) { + PROTECT( + auto outputs__ = torch::_sparse_semi_structured_apply_dense(*input, *thread_masks); + out__[0] = new torch::Tensor(outputs__); + ) +} + void atg__sparse_semi_structured_linear(tensor *out__, tensor input, tensor weight, tensor meta, tensor bias, char* activation_ptr, int activation_len, int out_dtype) { PROTECT( auto outputs__ = torch::_sparse_semi_structured_linear(*input, *weight, *meta, (bias ? *bias : torch::Tensor()), std::string(activation_ptr, activation_len), out_dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(out_dtype))); @@ -2607,6 +2674,24 @@ void atg__sparse_semi_structured_linear(tensor *out__, tensor input, tensor weig ) } +void atg__sparse_semi_structured_mm(tensor *out__, tensor mat1, tensor mat1_meta, tensor mat2, int out_dtype) { + PROTECT( + auto outputs__ = torch::_sparse_semi_structured_mm(*mat1, *mat1_meta, *mat2, out_dtype < 0 ? 
c10::nullopt : c10::optional(at::ScalarType(out_dtype))); + out__[0] = new torch::Tensor(outputs__); + ) +} + +void atg__sparse_semi_structured_tile(tensor *out__, tensor input, char* algorithm_ptr, int algorithm_len, int use_cutlass) { + PROTECT( + auto outputs__ = torch::_sparse_semi_structured_tile(*input, std::string(algorithm_ptr, algorithm_len), (bool)use_cutlass); + out__[0] = new torch::Tensor(std::get<0>(outputs__)); + out__[1] = new torch::Tensor(std::get<1>(outputs__)); + out__[2] = new torch::Tensor(std::get<2>(outputs__)); + out__[3] = new torch::Tensor(std::get<3>(outputs__)); + out__[4] = new torch::Tensor(std::get<4>(outputs__)); + ) +} + void atg__sparse_softmax(tensor *out__, tensor self, int64_t dim, int half_to_float) { PROTECT( auto outputs__ = torch::_sparse_softmax(*self, dim, (bool)half_to_float); @@ -3164,20 +3249,6 @@ void atg__unpack_dual(tensor *out__, tensor dual, int64_t level) { ) } -void atg__unsafe_index(tensor *out__, tensor self, tensor *indices_data, int indices_len) { - PROTECT( - auto outputs__ = torch::_unsafe_index(*self, of_carray_tensor_opt(indices_data, indices_len)); - out__[0] = new torch::Tensor(outputs__); - ) -} - -void atg__unsafe_index_put(tensor *out__, tensor self, tensor *indices_data, int indices_len, tensor values, int accumulate) { - PROTECT( - auto outputs__ = torch::_unsafe_index_put(*self, of_carray_tensor_opt(indices_data, indices_len), *values, (bool)accumulate); - out__[0] = new torch::Tensor(outputs__); - ) -} - void atg__unsafe_view(tensor *out__, tensor self, int64_t *size_data, int size_len) { PROTECT( auto outputs__ = torch::_unsafe_view(*self, torch::IntArrayRef(size_data, size_len)); @@ -5286,9 +5357,9 @@ void atg_bucketize_tensor_out(tensor *out__, tensor out, tensor self, tensor bou ) } -int atg_can_cast(int from, int to) { +int atg_can_cast(int from_, int to) { PROTECT( - return torch::can_cast(at::ScalarType(from), at::ScalarType(to)); + return torch::can_cast(at::ScalarType(from_), at::ScalarType(to)); ) return 0; } @@ -9144,13 +9215,6 @@ void atg_imag(tensor *out__, tensor self) { ) } -void atg_index(tensor *out__, tensor self, tensor *indices_data, int indices_len) { - PROTECT( - auto outputs__ = torch::index(*self, of_carray_tensor_opt(indices_data, indices_len)); - out__[0] = new torch::Tensor(outputs__); - ) -} - void atg_index_add(tensor *out__, tensor self, int64_t dim, tensor index, tensor source) { PROTECT( auto outputs__ = torch::index_add(*self, dim, *index, *source); @@ -9235,27 +9299,6 @@ void atg_index_fill_int_tensor_out(tensor *out__, tensor out, tensor self, int64 ) } -void atg_index_put(tensor *out__, tensor self, tensor *indices_data, int indices_len, tensor values, int accumulate) { - PROTECT( - auto outputs__ = torch::index_put(*self, of_carray_tensor_opt(indices_data, indices_len), *values, (bool)accumulate); - out__[0] = new torch::Tensor(outputs__); - ) -} - -void atg_index_put_(tensor *out__, tensor self, tensor *indices_data, int indices_len, tensor values, int accumulate) { - PROTECT( - auto outputs__ = torch::index_put_(*self, of_carray_tensor_opt(indices_data, indices_len), *values, (bool)accumulate); - out__[0] = new torch::Tensor(outputs__); - ) -} - -void atg_index_put_out(tensor *out__, tensor out, tensor self, tensor *indices_data, int indices_len, tensor values, int accumulate) { - PROTECT( - auto outputs__ = torch::index_put_out(*out, *self, of_carray_tensor_opt(indices_data, indices_len), *values, (bool)accumulate); - out__[0] = new torch::Tensor(outputs__); - ) -} - void 
atg_index_reduce(tensor *out__, tensor self, int64_t dim, tensor index, tensor source, char* reduce_ptr, int reduce_len, int include_self) { PROTECT( auto outputs__ = torch::index_reduce(*self, dim, *index, *source, std::string(reduce_ptr, reduce_len), (bool)include_self); @@ -9298,13 +9341,6 @@ void atg_index_select_out(tensor *out__, tensor out, tensor self, int64_t dim, t ) } -void atg_index_tensor_out(tensor *out__, tensor out, tensor self, tensor *indices_data, int indices_len) { - PROTECT( - auto outputs__ = torch::index_out(*out, *self, of_carray_tensor_opt(indices_data, indices_len)); - out__[0] = new torch::Tensor(outputs__); - ) -} - void atg_indices(tensor *out__, tensor self) { PROTECT( auto outputs__ = self->indices(); @@ -12021,16 +12057,16 @@ void atg_mkldnn_reorder_conv2d_weight_out(tensor *out__, tensor out, tensor self ) } -void atg_mkldnn_reorder_conv3d_weight(tensor *out__, tensor self, int64_t *padding_data, int padding_len, int64_t *stride_data, int stride_len, int64_t *dilation_data, int dilation_len, int64_t groups) { +void atg_mkldnn_reorder_conv3d_weight(tensor *out__, tensor self, int64_t *padding_data, int padding_len, int64_t *stride_data, int stride_len, int64_t *dilation_data, int dilation_len, int64_t groups, int64_t *input_size_data, int input_size_len) { PROTECT( - auto outputs__ = torch::mkldnn_reorder_conv3d_weight(*self, torch::IntArrayRef(padding_data, padding_len), torch::IntArrayRef(stride_data, stride_len), torch::IntArrayRef(dilation_data, dilation_len), groups); + auto outputs__ = torch::mkldnn_reorder_conv3d_weight(*self, torch::IntArrayRef(padding_data, padding_len), torch::IntArrayRef(stride_data, stride_len), torch::IntArrayRef(dilation_data, dilation_len), groups, input_size_data == nullptr ? c10::nullopt : c10::optional(torch::IntArrayRef(input_size_data, input_size_len))); out__[0] = new torch::Tensor(outputs__); ) } -void atg_mkldnn_reorder_conv3d_weight_out(tensor *out__, tensor out, tensor self, int64_t *padding_data, int padding_len, int64_t *stride_data, int stride_len, int64_t *dilation_data, int dilation_len, int64_t groups) { +void atg_mkldnn_reorder_conv3d_weight_out(tensor *out__, tensor out, tensor self, int64_t *padding_data, int padding_len, int64_t *stride_data, int stride_len, int64_t *dilation_data, int dilation_len, int64_t groups, int64_t *input_size_data, int input_size_len) { PROTECT( - auto outputs__ = torch::mkldnn_reorder_conv3d_weight_out(*out, *self, torch::IntArrayRef(padding_data, padding_len), torch::IntArrayRef(stride_data, stride_len), torch::IntArrayRef(dilation_data, dilation_len), groups); + auto outputs__ = torch::mkldnn_reorder_conv3d_weight_out(*out, *self, torch::IntArrayRef(padding_data, padding_len), torch::IntArrayRef(stride_data, stride_len), torch::IntArrayRef(dilation_data, dilation_len), groups, input_size_data == nullptr ? c10::nullopt : c10::optional(torch::IntArrayRef(input_size_data, input_size_len))); out__[0] = new torch::Tensor(outputs__); ) } @@ -14353,6 +14389,13 @@ int atg_retains_grad(tensor self) { return 0; } +void atg_rms_norm(tensor *out__, tensor input, int64_t *normalized_shape_data, int normalized_shape_len, tensor weight, double eps_v, uint8_t eps_null) { + PROTECT( + auto outputs__ = torch::rms_norm(*input, torch::IntArrayRef(normalized_shape_data, normalized_shape_len), (weight ? *weight : torch::Tensor()), eps_null ? 
c10::nullopt : c10::optional<double>(eps_v));
+    out__[0] = new torch::Tensor(outputs__);
+  )
+}
+
 void atg_rnn_relu(tensor *out__, tensor input, tensor hx, tensor *params_data, int params_len, int has_biases, int64_t num_layers, double dropout, int train, int bidirectional, int batch_first) {
   PROTECT(
     auto outputs__ = torch::rnn_relu(*input, *hx, of_carray_tensor(params_data, params_len), (bool)has_biases, num_layers, dropout, (bool)train, (bool)bidirectional, (bool)batch_first);
diff --git a/torch-sys/libtch/torch_api_generated.h b/torch-sys/libtch/torch_api_generated.h
index da3bc167..5cec0607 100644
--- a/torch-sys/libtch/torch_api_generated.h
+++ b/torch-sys/libtch/torch_api_generated.h
@@ -54,6 +54,11 @@ void atg__assert_scalar(scalar self_scalar, char* assert_msg_ptr, int assert_msg
 void atg__assert_tensor_metadata(tensor a, int64_t *size_data, int size_len, int64_t *stride_data, int stride_len, int dtype);
 void atg__autocast_to_full_precision(tensor *, tensor self, int cuda_enabled, int cpu_enabled);
 void atg__autocast_to_reduced_precision(tensor *, tensor self, int cuda_enabled, int cpu_enabled, int cuda_dtype, int cpu_dtype);
+void atg__batch_norm_no_update(tensor *, tensor input, tensor weight, tensor bias, tensor running_mean, tensor running_var, double momentum, double eps);
+void atg__batch_norm_no_update_out(tensor *, tensor out0, tensor out1, tensor out2, tensor out3, tensor input, tensor weight, tensor bias, tensor running_mean, tensor running_var, double momentum, double eps);
+void atg__batch_norm_with_update(tensor *, tensor input, tensor weight, tensor bias, tensor running_mean, tensor running_var, double momentum, double eps);
+void atg__batch_norm_with_update_functional(tensor *, tensor input, tensor weight, tensor bias, tensor running_mean, tensor running_var, double momentum, double eps);
+void atg__batch_norm_with_update_out(tensor *, tensor out, tensor save_mean, tensor save_invstd, tensor reserve, tensor input, tensor weight, tensor bias, tensor running_mean, tensor running_var, double momentum, double eps);
 void atg__cast_byte(tensor *, tensor self, int non_blocking);
 void atg__cast_char(tensor *, tensor self, int non_blocking);
 void atg__cast_double(tensor *, tensor self, int non_blocking);
@@ -120,7 +125,7 @@ int64_t atg__dimi(tensor self);
 int64_t atg__dimv(tensor self);
 void atg__dirichlet_grad(tensor *, tensor x, tensor alpha, tensor total);
 void atg__dirichlet_grad_out(tensor *, tensor out, tensor x, tensor alpha, tensor total);
-void atg__efficient_attention_backward(tensor *, tensor grad_out_, tensor query, tensor key, tensor value, tensor bias, tensor out, tensor cu_seqlens_q, tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, tensor logsumexp, double dropout_p, tensor philox_seed, tensor philox_offset, int64_t custom_mask_type, int bias_requires_grad, double scale_v, uint8_t scale_null, int64_t num_splits_key_v, uint8_t num_splits_key_null);
+void atg__efficient_attention_backward(tensor *, tensor grad_out_, tensor query, tensor key, tensor value, tensor bias, tensor out, tensor cu_seqlens_q, tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, tensor logsumexp, double dropout_p, tensor philox_seed, tensor philox_offset, int64_t custom_mask_type, int bias_requires_grad, double scale_v, uint8_t scale_null, int64_t num_splits_key_v, uint8_t num_splits_key_null, int64_t window_size_v, uint8_t window_size_null, int shared_storage_dqdkdv);
 void atg__efficientzerotensor(tensor *, int64_t *size_data, int size_len, int options_kind, int
options_device); void atg__efficientzerotensor_out(tensor *, tensor out, int64_t *size_data, int size_len); void atg__embedding_bag(tensor *, tensor weight, tensor indices, tensor offsets, int scale_grad_by_freq, int64_t mode, int sparse, tensor per_sample_weights, int include_last_offset, int64_t padding_idx); @@ -154,7 +159,7 @@ void atg__fft_c2r_out(tensor *, tensor out, tensor self, int64_t *dim_data, int void atg__fft_r2c(tensor *, tensor self, int64_t *dim_data, int dim_len, int64_t normalization, int onesided); void atg__fft_r2c_out(tensor *, tensor out, tensor self, int64_t *dim_data, int dim_len, int64_t normalization, int onesided); void atg__fill_mem_eff_dropout_mask_(tensor *, tensor self, double dropout_p, int64_t seed, int64_t offset); -void atg__flash_attention_backward(tensor *, tensor grad_out, tensor query, tensor key, tensor value, tensor out, tensor logsumexp, tensor cum_seq_q, tensor cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int is_causal, tensor philox_seed, tensor philox_offset, double scale_v, uint8_t scale_null); +void atg__flash_attention_backward(tensor *, tensor grad_out, tensor query, tensor key, tensor value, tensor out, tensor logsumexp, tensor cum_seq_q, tensor cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int is_causal, tensor philox_seed, tensor philox_offset, double scale_v, uint8_t scale_null, int64_t window_size_left_v, uint8_t window_size_left_null, int64_t window_size_right_v, uint8_t window_size_right_null); void atg__foobar(tensor *, tensor self, int arg1, int arg2, int arg3); void atg__foobar_out(tensor *, tensor out, tensor self, int arg1, int arg2, int arg3); void atg__functional_assert_async(tensor *, tensor self, char* assert_msg_ptr, int assert_msg_len, tensor dep_token); @@ -182,9 +187,6 @@ void atg__histogramdd_from_bin_cts(tensor *, tensor self, int64_t *bins_data, in void atg__histogramdd_from_bin_cts_out(tensor *, tensor out, tensor self, int64_t *bins_data, int bins_len, double *range_data, int range_len, tensor weight, int density); void atg__histogramdd_from_bin_tensors(tensor *, tensor self, tensor *bins_data, int bins_len, tensor weight, int density); void atg__histogramdd_from_bin_tensors_out(tensor *, tensor out, tensor self, tensor *bins_data, int bins_len, tensor weight, int density); -void atg__index_put_impl(tensor *, tensor self, tensor *indices_data, int indices_len, tensor values, int accumulate, int unsafe); -void atg__index_put_impl_(tensor *, tensor self, tensor *indices_data, int indices_len, tensor values, int accumulate, int unsafe); -void atg__index_put_impl_out(tensor *, tensor out, tensor self, tensor *indices_data, int indices_len, tensor values, int accumulate, int unsafe); void atg__indices(tensor *, tensor self); void atg__indices_copy(tensor *, tensor self); void atg__indices_copy_out(tensor *, tensor out, tensor self); @@ -251,6 +253,7 @@ void atg__native_multi_head_attention_out(tensor *, tensor out0, tensor out1, te void atg__neg_view(tensor *, tensor self); void atg__neg_view_copy(tensor *, tensor self); void atg__neg_view_copy_out(tensor *, tensor out, tensor self); +void atg__nested_compute_contiguous_strides_offsets(tensor *, tensor nested_size); void atg__nested_from_padded(tensor *, tensor padded, tensor cpu_nested_shape_example, int fuse_transform_0213); void atg__nested_from_padded_and_nested_example(tensor *, tensor padded, tensor nt_example); void atg__nested_from_padded_and_nested_example_out(tensor *, tensor out, tensor padded, tensor nt_example); @@ -304,7 +307,7 @@ 
 void atg__sample_dirichlet_out(tensor *, tensor out, tensor self);
 void atg__saturate_weight_to_fp16(tensor *, tensor weight);
 void atg__scaled_dot_product_attention_math(tensor *, tensor query, tensor key, tensor value, tensor attn_mask, double dropout_p, int is_causal, tensor dropout_mask, double scale_v, uint8_t scale_null);
-void atg__scaled_dot_product_cudnn_attention(tensor *, tensor query, tensor key, tensor value, double dropout_p, int is_causal, int return_debug_mask, double scale_v, uint8_t scale_null);
+void atg__scaled_dot_product_cudnn_attention_backward(tensor *, tensor grad_out, tensor query, tensor key, tensor value, tensor out, tensor logsumexp, tensor cum_seq_q, tensor cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int is_causal, tensor philox_seed, tensor philox_offset, double scale_v, uint8_t scale_null);
 void atg__scaled_dot_product_efficient_attention(tensor *, tensor query, tensor key, tensor value, tensor attn_bias, int compute_log_sumexp, double dropout_p, int is_causal, double scale_v, uint8_t scale_null);
 void atg__scaled_dot_product_flash_attention_backward(tensor *, tensor grad_out, tensor query, tensor key, tensor value, tensor out, tensor logsumexp, tensor cum_seq_q, tensor cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int is_causal, tensor philox_seed, tensor philox_offset, double scale_v, uint8_t scale_null);
 void atg__scaled_dot_product_flash_attention_for_cpu(tensor *, tensor query, tensor key, tensor value, double dropout_p, int is_causal, tensor attn_mask, double scale_v, uint8_t scale_null);
@@ -334,6 +337,7 @@ void atg__sparse_broadcast_to_copy_out(tensor *, tensor out, tensor self, int64_
 void atg__sparse_bsc_tensor_unsafe(tensor *, tensor ccol_indices, tensor row_indices, tensor values, int64_t *size_data, int size_len, int options_kind, int options_device);
 void atg__sparse_bsr_tensor_unsafe(tensor *, tensor crow_indices, tensor col_indices, tensor values, int64_t *size_data, int size_len, int options_kind, int options_device);
 void atg__sparse_compressed_tensor_unsafe(tensor *, tensor compressed_indices, tensor plain_indices, tensor values, int64_t *size_data, int size_len, int options_kind, int options_device);
+void atg__sparse_compressed_tensor_with_dims(tensor *, int64_t nnz, int64_t dense_dim, int64_t *size_data, int size_len, int64_t *blocksize_data, int blocksize_len, int index_dtype, int options_kind, int options_device);
 void atg__sparse_coo_tensor_unsafe(tensor *, tensor indices, tensor values, int64_t *size_data, int size_len, int options_kind, int options_device, int is_coalesced);
 void atg__sparse_coo_tensor_with_dims(tensor *, int64_t sparse_dim, int64_t dense_dim, int64_t *size_data, int size_len, int options_kind, int options_device);
 void atg__sparse_coo_tensor_with_dims_and_tensors(tensor *, int64_t sparse_dim, int64_t dense_dim, int64_t *size_data, int size_len, tensor indices, tensor values, int options_kind, int options_device, int is_coalesced);
@@ -355,7 +359,12 @@ void atg__sparse_mask_projection_out(tensor *, tensor out, tensor self, tensor m
 void atg__sparse_mm(tensor *, tensor sparse, tensor dense);
 void atg__sparse_mm_reduce(tensor *, tensor sparse, tensor dense, char* reduce_ptr, int reduce_len);
 void atg__sparse_mm_reduce_impl(tensor *, tensor self, tensor other, char* reduce_ptr, int reduce_len);
+void atg__sparse_semi_structured_addmm(tensor *, tensor input, tensor mat1, tensor mat1_meta, tensor mat2, int out_dtype);
+void atg__sparse_semi_structured_apply(tensor *, tensor input, tensor thread_masks);
+void atg__sparse_semi_structured_apply_dense(tensor *, tensor input, tensor thread_masks);
 void atg__sparse_semi_structured_linear(tensor *, tensor input, tensor weight, tensor meta, tensor bias, char* activation_ptr, int activation_len, int out_dtype);
+void atg__sparse_semi_structured_mm(tensor *, tensor mat1, tensor mat1_meta, tensor mat2, int out_dtype);
+void atg__sparse_semi_structured_tile(tensor *, tensor input, char* algorithm_ptr, int algorithm_len, int use_cutlass);
 void atg__sparse_softmax(tensor *, tensor self, int64_t dim, int half_to_float);
 void atg__sparse_softmax_backward_data(tensor *, tensor grad_output, tensor output, int64_t dim, tensor self);
 void atg__sparse_softmax_backward_data_out(tensor *, tensor out, tensor grad_output, tensor output, int64_t dim, tensor self);
@@ -433,8 +442,6 @@ void atg__unique2(tensor *, tensor self, int sorted, int return_inverse, int ret
 void atg__unique2_out(tensor *, tensor out0, tensor out1, tensor out2, tensor self, int sorted, int return_inverse, int return_counts);
 void atg__unique_out(tensor *, tensor out0, tensor out1, tensor self, int sorted, int return_inverse);
 void atg__unpack_dual(tensor *, tensor dual, int64_t level);
-void atg__unsafe_index(tensor *, tensor self, tensor *indices_data, int indices_len);
-void atg__unsafe_index_put(tensor *, tensor self, tensor *indices_data, int indices_len, tensor values, int accumulate);
 void atg__unsafe_view(tensor *, tensor self, int64_t *size_data, int size_len);
 void atg__unsafe_view_out(tensor *, tensor out, tensor self, int64_t *size_data, int size_len);
 void atg__upsample_bicubic2d_aa(tensor *, tensor self, int64_t *output_size_data, int output_size_len, int align_corners, double scales_h_v, uint8_t scales_h_null, double scales_w_v, uint8_t scales_w_null);
@@ -729,7 +736,7 @@ void atg_bucketize(tensor *, tensor self, tensor boundaries, int out_int32, int
 void atg_bucketize_scalar(tensor *, scalar self_scalar, tensor boundaries, int out_int32, int right);
 void atg_bucketize_scalar_out(tensor *, tensor out, scalar self_scalar, tensor boundaries, int out_int32, int right);
 void atg_bucketize_tensor_out(tensor *, tensor out, tensor self, tensor boundaries, int out_int32, int right);
-int atg_can_cast(int from, int to);
+int atg_can_cast(int from_, int to);
 void atg_cartesian_prod(tensor *, tensor *tensors_data, int tensors_len);
 void atg_cat(tensor *, tensor *tensors_data, int tensors_len, int64_t dim);
 void atg_cat_out(tensor *, tensor out, tensor *tensors_data, int tensors_len, int64_t dim);
@@ -1269,7 +1276,6 @@ void atg_igammac_out(tensor *, tensor out, tensor self, tensor other);
 void atg_im2col(tensor *, tensor self, int64_t *kernel_size_data, int kernel_size_len, int64_t *dilation_data, int dilation_len, int64_t *padding_data, int padding_len, int64_t *stride_data, int stride_len);
 void atg_im2col_out(tensor *, tensor out, tensor self, int64_t *kernel_size_data, int kernel_size_len, int64_t *dilation_data, int dilation_len, int64_t *padding_data, int padding_len, int64_t *stride_data, int stride_len);
 void atg_imag(tensor *, tensor self);
-void atg_index(tensor *, tensor self, tensor *indices_data, int indices_len);
 void atg_index_add(tensor *, tensor self, int64_t dim, tensor index, tensor source);
 void atg_index_add_(tensor *, tensor self, int64_t dim, tensor index, tensor source);
 void atg_index_add_out(tensor *, tensor out, tensor self, int64_t dim, tensor index, tensor source);
@@ -1282,16 +1288,12 @@ void atg_index_fill_int_scalar_out(tensor *, tensor out, tensor self, int64_t di
 void atg_index_fill_int_tensor(tensor *, tensor self, int64_t dim, tensor index, tensor value);
 void atg_index_fill_int_tensor_(tensor *, tensor self, int64_t dim, tensor index, tensor value);
 void atg_index_fill_int_tensor_out(tensor *, tensor out, tensor self, int64_t dim, tensor index, tensor value);
-void atg_index_put(tensor *, tensor self, tensor *indices_data, int indices_len, tensor values, int accumulate);
-void atg_index_put_(tensor *, tensor self, tensor *indices_data, int indices_len, tensor values, int accumulate);
-void atg_index_put_out(tensor *, tensor out, tensor self, tensor *indices_data, int indices_len, tensor values, int accumulate);
 void atg_index_reduce(tensor *, tensor self, int64_t dim, tensor index, tensor source, char* reduce_ptr, int reduce_len, int include_self);
 void atg_index_reduce_(tensor *, tensor self, int64_t dim, tensor index, tensor source, char* reduce_ptr, int reduce_len, int include_self);
 void atg_index_reduce_out(tensor *, tensor out, tensor self, int64_t dim, tensor index, tensor source, char* reduce_ptr, int reduce_len, int include_self);
 void atg_index_select(tensor *, tensor self, int64_t dim, tensor index);
 void atg_index_select_backward(tensor *, tensor grad, int64_t *self_sizes_data, int self_sizes_len, int64_t dim, tensor index);
 void atg_index_select_out(tensor *, tensor out, tensor self, int64_t dim, tensor index);
-void atg_index_tensor_out(tensor *, tensor out, tensor self, tensor *indices_data, int indices_len);
 void atg_indices(tensor *, tensor self);
 void atg_indices_copy(tensor *, tensor self);
 void atg_indices_copy_out(tensor *, tensor out, tensor self);
@@ -1667,8 +1669,8 @@ void atg_mkldnn_max_pool3d_backward_out(tensor *, tensor out, tensor grad_output
 void atg_mkldnn_max_pool3d_out(tensor *, tensor out, tensor self, int64_t *kernel_size_data, int kernel_size_len, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *dilation_data, int dilation_len, int ceil_mode);
 void atg_mkldnn_reorder_conv2d_weight(tensor *, tensor self, int64_t *padding_data, int padding_len, int64_t *stride_data, int stride_len, int64_t *dilation_data, int dilation_len, int64_t groups, int64_t *input_size_data, int input_size_len);
 void atg_mkldnn_reorder_conv2d_weight_out(tensor *, tensor out, tensor self, int64_t *padding_data, int padding_len, int64_t *stride_data, int stride_len, int64_t *dilation_data, int dilation_len, int64_t groups, int64_t *input_size_data, int input_size_len);
-void atg_mkldnn_reorder_conv3d_weight(tensor *, tensor self, int64_t *padding_data, int padding_len, int64_t *stride_data, int stride_len, int64_t *dilation_data, int dilation_len, int64_t groups);
-void atg_mkldnn_reorder_conv3d_weight_out(tensor *, tensor out, tensor self, int64_t *padding_data, int padding_len, int64_t *stride_data, int stride_len, int64_t *dilation_data, int dilation_len, int64_t groups);
+void atg_mkldnn_reorder_conv3d_weight(tensor *, tensor self, int64_t *padding_data, int padding_len, int64_t *stride_data, int stride_len, int64_t *dilation_data, int dilation_len, int64_t groups, int64_t *input_size_data, int input_size_len);
+void atg_mkldnn_reorder_conv3d_weight_out(tensor *, tensor out, tensor self, int64_t *padding_data, int padding_len, int64_t *stride_data, int stride_len, int64_t *dilation_data, int dilation_len, int64_t groups, int64_t *input_size_data, int input_size_len);
 void atg_mkldnn_rnn_layer(tensor *, tensor input, tensor weight0, tensor weight1, tensor weight2, tensor weight3, tensor hx_, tensor cx_, int reverse, int64_t *batch_sizes_data, int batch_sizes_len, int64_t mode, int64_t hidden_size, int64_t num_layers, int has_biases, int bidirectional, int batch_first, int train);
 void atg_mkldnn_rnn_layer_backward(tensor *, tensor input, tensor weight1, tensor weight2, tensor weight3, tensor weight4, tensor hx_, tensor cx_tmp, tensor output, tensor hy_, tensor cy_, tensor grad_output, tensor grad_hy, tensor grad_cy, int reverse, int64_t mode, int64_t hidden_size, int64_t num_layers, int has_biases, int train, int bidirectional, int64_t *batch_sizes_data, int batch_sizes_len, int batch_first, tensor workspace);
 void atg_mkldnn_rnn_layer_backward_out(tensor *, tensor out0, tensor out1, tensor out2, tensor out3, tensor out4, tensor out5, tensor out6, tensor input, tensor weight1, tensor weight2, tensor weight3, tensor weight4, tensor hx_, tensor cx_tmp, tensor output, tensor hy_, tensor cy_, tensor grad_output, tensor grad_hy, tensor grad_cy, int reverse, int64_t mode, int64_t hidden_size, int64_t num_layers, int has_biases, int train, int bidirectional, int64_t *batch_sizes_data, int batch_sizes_len, int batch_first, tensor workspace);
@@ -1993,6 +1995,7 @@ void atg_resize_out(tensor *, tensor out, tensor self, int64_t *size_data, int s
 void atg_resolve_conj(tensor *, tensor self);
 void atg_resolve_neg(tensor *, tensor self);
 int atg_retains_grad(tensor self);
+void atg_rms_norm(tensor *, tensor input, int64_t *normalized_shape_data, int normalized_shape_len, tensor weight, double eps_v, uint8_t eps_null);
 void atg_rnn_relu(tensor *, tensor input, tensor hx, tensor *params_data, int params_len, int has_biases, int64_t num_layers, double dropout, int train, int bidirectional, int batch_first);
 void atg_rnn_relu_cell(tensor *, tensor input, tensor hx, tensor w_ih, tensor w_hh, tensor b_ih, tensor b_hh);
 void atg_rnn_relu_data(tensor *, tensor data, tensor batch_sizes, tensor hx, tensor *params_data, int params_len, int has_biases, int64_t num_layers, double dropout, int train, int bidirectional);
diff --git a/torch-sys/src/c_generated.rs b/torch-sys/src/c_generated.rs
index a84e9808..5b8a8b8a 100644
--- a/torch-sys/src/c_generated.rs
+++ b/torch-sys/src/c_generated.rs
@@ -257,6 +257,64 @@ extern "C" {
         cuda_dtype_: c_int,
         cpu_dtype_: c_int,
     );
+    pub fn atg__batch_norm_no_update(
+        out__: *mut *mut C_tensor,
+        input_: *mut C_tensor,
+        weight_: *mut C_tensor,
+        bias_: *mut C_tensor,
+        running_mean_: *mut C_tensor,
+        running_var_: *mut C_tensor,
+        momentum_: f64,
+        eps_: f64,
+    );
+    pub fn atg__batch_norm_no_update_out(
+        out__: *mut *mut C_tensor,
+        out0_: *mut C_tensor,
+        out1_: *mut C_tensor,
+        out2_: *mut C_tensor,
+        out3_: *mut C_tensor,
+        input_: *mut C_tensor,
+        weight_: *mut C_tensor,
+        bias_: *mut C_tensor,
+        running_mean_: *mut C_tensor,
+        running_var_: *mut C_tensor,
+        momentum_: f64,
+        eps_: f64,
+    );
+    pub fn atg__batch_norm_with_update(
+        out__: *mut *mut C_tensor,
+        input_: *mut C_tensor,
+        weight_: *mut C_tensor,
+        bias_: *mut C_tensor,
+        running_mean_: *mut C_tensor,
+        running_var_: *mut C_tensor,
+        momentum_: f64,
+        eps_: f64,
+    );
+    pub fn atg__batch_norm_with_update_functional(
+        out__: *mut *mut C_tensor,
+        input_: *mut C_tensor,
+        weight_: *mut C_tensor,
+        bias_: *mut C_tensor,
+        running_mean_: *mut C_tensor,
+        running_var_: *mut C_tensor,
+        momentum_: f64,
+        eps_: f64,
+    );
+    pub fn atg__batch_norm_with_update_out(
+        out__: *mut *mut C_tensor,
+        out_: *mut C_tensor,
+        save_mean_: *mut C_tensor,
+        save_invstd_: *mut C_tensor,
+        reserve_: *mut C_tensor,
+        input_: *mut C_tensor,
+        weight_: *mut C_tensor,
+        bias_: *mut C_tensor,
+        running_mean_: *mut C_tensor,
+        running_var_: *mut C_tensor,
+        momentum_: f64,
+        eps_: f64,
+    );
     pub fn atg__cast_byte(out__: *mut *mut C_tensor, self_: *mut C_tensor, non_blocking_: c_int);
     pub fn atg__cast_char(out__: *mut *mut C_tensor, self_: *mut C_tensor, non_blocking_: c_int);
     pub fn atg__cast_double(out__: *mut *mut C_tensor, self_: *mut C_tensor, non_blocking_: c_int);
@@ -767,6 +825,9 @@ extern "C" {
         scale_null: i8,
         num_splits_key_v: i64,
         num_splits_key_null: i8,
+        window_size_v: i64,
+        window_size_null: i8,
+        shared_storage_dqdkdv_: c_int,
     );
     pub fn atg__efficientzerotensor(
         out__: *mut *mut C_tensor,
@@ -1113,6 +1174,10 @@ extern "C" {
         philox_offset_: *mut C_tensor,
         scale_v: f64,
         scale_null: i8,
+        window_size_left_v: i64,
+        window_size_left_null: i8,
+        window_size_right_v: i64,
+        window_size_right_null: i8,
     );
     pub fn atg__foobar(
         out__: *mut *mut C_tensor,
@@ -1333,34 +1398,6 @@ extern "C" {
         weight_: *mut C_tensor,
         density_: c_int,
     );
-    pub fn atg__index_put_impl(
-        out__: *mut *mut C_tensor,
-        self_: *mut C_tensor,
-        indices_data: *const *mut C_tensor,
-        indices_len: c_int,
-        values_: *mut C_tensor,
-        accumulate_: c_int,
-        unsafe_: c_int,
-    );
-    pub fn atg__index_put_impl_(
-        out__: *mut *mut C_tensor,
-        self_: *mut C_tensor,
-        indices_data: *const *mut C_tensor,
-        indices_len: c_int,
-        values_: *mut C_tensor,
-        accumulate_: c_int,
-        unsafe_: c_int,
-    );
-    pub fn atg__index_put_impl_out(
-        out__: *mut *mut C_tensor,
-        out_: *mut C_tensor,
-        self_: *mut C_tensor,
-        indices_data: *const *mut C_tensor,
-        indices_len: c_int,
-        values_: *mut C_tensor,
-        accumulate_: c_int,
-        unsafe_: c_int,
-    );
     pub fn atg__indices(out__: *mut *mut C_tensor, self_: *mut C_tensor);
     pub fn atg__indices_copy(out__: *mut *mut C_tensor, self_: *mut C_tensor);
     pub fn atg__indices_copy_out(
@@ -1852,6 +1889,10 @@ extern "C" {
         out_: *mut C_tensor,
         self_: *mut C_tensor,
     );
+    pub fn atg__nested_compute_contiguous_strides_offsets(
+        out__: *mut *mut C_tensor,
+        nested_size_: *mut C_tensor,
+    );
     pub fn atg__nested_from_padded(
         out__: *mut *mut C_tensor,
         padded_: *mut C_tensor,
@@ -2155,14 +2196,22 @@ extern "C" {
         scale_v: f64,
         scale_null: i8,
     );
-    pub fn atg__scaled_dot_product_cudnn_attention(
+    pub fn atg__scaled_dot_product_cudnn_attention_backward(
        out__: *mut *mut C_tensor,
+        grad_out_: *mut C_tensor,
         query_: *mut C_tensor,
         key_: *mut C_tensor,
         value_: *mut C_tensor,
+        out_: *mut C_tensor,
+        logsumexp_: *mut C_tensor,
+        cum_seq_q_: *mut C_tensor,
+        cum_seq_k_: *mut C_tensor,
+        max_q_: i64,
+        max_k_: i64,
         dropout_p_: f64,
         is_causal_: c_int,
-        return_debug_mask_: c_int,
+        philox_seed_: *mut C_tensor,
+        philox_offset_: *mut C_tensor,
         scale_v: f64,
         scale_null: i8,
     );
@@ -2436,6 +2485,18 @@ extern "C" {
         options_kind: c_int,
         options_device: c_int,
     );
+    pub fn atg__sparse_compressed_tensor_with_dims(
+        out__: *mut *mut C_tensor,
+        nnz_: i64,
+        dense_dim_: i64,
+        size_data: *const i64,
+        size_len: c_int,
+        blocksize_data: *const i64,
+        blocksize_len: c_int,
+        index_dtype_: c_int,
+        options_kind: c_int,
+        options_device: c_int,
+    );
     pub fn atg__sparse_coo_tensor_unsafe(
         out__: *mut *mut C_tensor,
         indices_: *mut C_tensor,
@@ -2602,6 +2663,24 @@ extern "C" {
         reduce_ptr: *const u8,
         reduce_len: c_int,
     );
+    pub fn atg__sparse_semi_structured_addmm(
+        out__: *mut *mut C_tensor,
+        input_: *mut C_tensor,
+        mat1_: *mut C_tensor,
+        mat1_meta_: *mut C_tensor,
+        mat2_: *mut C_tensor,
+        out_dtype_: c_int,
+    );
+    pub fn atg__sparse_semi_structured_apply(
+        out__: *mut *mut C_tensor,
+        input_: *mut C_tensor,
+        thread_masks_: *mut C_tensor,
+    );
+    pub fn atg__sparse_semi_structured_apply_dense(
+        out__: *mut *mut C_tensor,
+        input_: *mut C_tensor,
+        thread_masks_: *mut C_tensor,
+    );
     pub fn atg__sparse_semi_structured_linear(
         out__: *mut *mut C_tensor,
         input_: *mut C_tensor,
@@ -2612,6 +2691,20 @@ extern "C" {
         activation_len: c_int,
         out_dtype_: c_int,
     );
+    pub fn atg__sparse_semi_structured_mm(
+        out__: *mut *mut C_tensor,
+        mat1_: *mut C_tensor,
+        mat1_meta_: *mut C_tensor,
+        mat2_: *mut C_tensor,
+        out_dtype_: c_int,
+    );
+    pub fn atg__sparse_semi_structured_tile(
+        out__: *mut *mut C_tensor,
+        input_: *mut C_tensor,
+        algorithm_ptr: *const u8,
+        algorithm_len: c_int,
+        use_cutlass_: c_int,
+    );
     pub fn atg__sparse_softmax(
         out__: *mut *mut C_tensor,
         self_: *mut C_tensor,
@@ -3144,20 +3237,6 @@ extern "C" {
         return_inverse_: c_int,
     );
     pub fn atg__unpack_dual(out__: *mut *mut C_tensor, dual_: *mut C_tensor, level_: i64);
-    pub fn atg__unsafe_index(
-        out__: *mut *mut C_tensor,
-        self_: *mut C_tensor,
-        indices_data: *const *mut C_tensor,
-        indices_len: c_int,
-    );
-    pub fn atg__unsafe_index_put(
-        out__: *mut *mut C_tensor,
-        self_: *mut C_tensor,
-        indices_data: *const *mut C_tensor,
-        indices_len: c_int,
-        values_: *mut C_tensor,
-        accumulate_: c_int,
-    );
     pub fn atg__unsafe_view(
         out__: *mut *mut C_tensor,
         self_: *mut C_tensor,
@@ -4855,7 +4934,7 @@ extern "C" {
         out_int32_: c_int,
         right_: c_int,
     );
-    pub fn atg_can_cast(from_: c_int, to_: c_int) -> c_int;
+    pub fn atg_can_cast(from__: c_int, to_: c_int) -> c_int;
     pub fn atg_cartesian_prod(
         out__: *mut *mut C_tensor,
         tensors_data: *const *mut C_tensor,
@@ -7901,12 +7980,6 @@ extern "C" {
         stride_len: c_int,
     );
     pub fn atg_imag(out__: *mut *mut C_tensor, self_: *mut C_tensor);
-    pub fn atg_index(
-        out__: *mut *mut C_tensor,
-        self_: *mut C_tensor,
-        indices_data: *const *mut C_tensor,
-        indices_len: c_int,
-    );
     pub fn atg_index_add(
         out__: *mut *mut C_tensor,
         self_: *mut C_tensor,
@@ -7995,31 +8068,6 @@ extern "C" {
         index_: *mut C_tensor,
         value_: *mut C_tensor,
     );
-    pub fn atg_index_put(
-        out__: *mut *mut C_tensor,
-        self_: *mut C_tensor,
-        indices_data: *const *mut C_tensor,
-        indices_len: c_int,
-        values_: *mut C_tensor,
-        accumulate_: c_int,
-    );
-    pub fn atg_index_put_(
-        out__: *mut *mut C_tensor,
-        self_: *mut C_tensor,
-        indices_data: *const *mut C_tensor,
-        indices_len: c_int,
-        values_: *mut C_tensor,
-        accumulate_: c_int,
-    );
-    pub fn atg_index_put_out(
-        out__: *mut *mut C_tensor,
-        out_: *mut C_tensor,
-        self_: *mut C_tensor,
-        indices_data: *const *mut C_tensor,
-        indices_len: c_int,
-        values_: *mut C_tensor,
-        accumulate_: c_int,
-    );
     pub fn atg_index_reduce(
         out__: *mut *mut C_tensor,
         self_: *mut C_tensor,
@@ -8072,13 +8120,6 @@ extern "C" {
         dim_: i64,
         index_: *mut C_tensor,
     );
-    pub fn atg_index_tensor_out(
-        out__: *mut *mut C_tensor,
-        out_: *mut C_tensor,
-        self_: *mut C_tensor,
-        indices_data: *const *mut C_tensor,
-        indices_len: c_int,
-    );
     pub fn atg_indices(out__: *mut *mut C_tensor, self_: *mut C_tensor);
     pub fn atg_indices_copy(out__: *mut *mut C_tensor, self_: *mut C_tensor);
     pub fn atg_indices_copy_out(
@@ -10303,6 +10344,8 @@ extern "C" {
         dilation_data: *const i64,
         dilation_len: c_int,
         groups_: i64,
+        input_size_data: *const i64,
+        input_size_len: c_int,
     );
     pub fn atg_mkldnn_reorder_conv3d_weight_out(
         out__: *mut *mut C_tensor,
@@ -10315,6 +10358,8 @@ extern "C" {
         dilation_data: *const i64,
         dilation_len: c_int,
         groups_: i64,
+        input_size_data: *const i64,
+        input_size_len: c_int,
     );
     pub fn atg_mkldnn_rnn_layer(
         out__: *mut *mut C_tensor,
@@ -12247,6 +12292,15 @@ extern "C" {
     pub fn atg_resolve_conj(out__: *mut *mut C_tensor, self_: *mut C_tensor);
     pub fn atg_resolve_neg(out__: *mut *mut C_tensor, self_: *mut C_tensor);
     pub fn atg_retains_grad(self_: *mut C_tensor) -> c_int;
+    pub fn atg_rms_norm(
+        out__: *mut *mut C_tensor,
+        input_: *mut C_tensor,
+        normalized_shape_data: *const i64,
+        normalized_shape_len: c_int,
+        weight_: *mut C_tensor,
+        eps_v: f64,
+        eps_null: i8,
+    );
     pub fn atg_rnn_relu(
         out__: *mut *mut C_tensor,
         input_: *mut C_tensor,