From 6ead8dc3975e8fc1dc36637c7e0b0e901f449b73 Mon Sep 17 00:00:00 2001 From: Eben Kadile Date: Fri, 22 Mar 2024 11:09:20 +0100 Subject: [PATCH 01/21] added one_hot --- Cargo.lock | 241 +++++++++++++++------------------- src/core/graph/autodiff.rs | 10 ++ src/core/graph/compile.rs | 24 +++- src/core/graph/consteval.rs | 251 ++++++++++++++++++++++-------------- src/core/graph/context.rs | 23 +++- src/core/graph/math.rs | 31 +++++ src/core/graph/operation.rs | 2 + 7 files changed, 335 insertions(+), 247 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 271d195..d364904 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -30,9 +30,9 @@ dependencies = [ [[package]] name = "aho-corasick" -version = "1.1.2" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" dependencies = [ "memchr", ] @@ -45,9 +45,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "backtrace" -version = "0.3.69" +version = "0.3.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" +checksum = "95d8e92cac0961e91dbd517496b00f7e9b92363dbe6d42c3198268323798860c" dependencies = [ "addr2line", "cc", @@ -94,9 +94,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.4.2" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed570934406eb16438a4e976b1b4500774099c13b8cb96eec99f620f05090ddf" +checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" [[package]] name = "block-buffer" @@ -136,10 +136,11 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.88" +version = "1.0.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02f341c093d19155a6e41631ce5971aac4e9a868262212153124c15fa22d1cdc" +checksum = "8cd6604a82acf3039f1144f54b8eb34e91ffba622051189e71b781822d5ee1f5" dependencies = [ + "jobserver", "libc", ] @@ -209,6 +210,12 @@ version = "0.8.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + [[package]] name = "crypto-common" version = "0.1.6" @@ -268,9 +275,9 @@ checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a" [[package]] name = "erased-serde" -version = "0.4.3" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "388979d208a049ffdfb22fa33b9c81942215b940910bccfe258caeb25d125cb3" +checksum = "2b73807008a3c7f171cc40312f37d95ef0396e048b5848d775f54b1a4dd4a0d3" dependencies = [ "serde", ] @@ -282,7 +289,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys", ] [[package]] @@ -351,7 +358,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.53", ] [[package]] @@ -419,9 +426,19 @@ checksum = 
"d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" [[package]] name = "half" -version = "1.8.2" +version = "1.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" +checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403" + +[[package]] +name = "half" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5eceaaeec696539ddaf7b333340f1af35a5aa87ae3e4f3ead0532f72affab2e" +dependencies = [ + "cfg-if", + "crunchy", +] [[package]] name = "heck" @@ -444,7 +461,7 @@ version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" dependencies = [ - "windows-sys 0.52.0", + "windows-sys", ] [[package]] @@ -477,6 +494,15 @@ version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" +[[package]] +name = "jobserver" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab46a6e9526ddef3ae7f787c06f0f2600639ba80ea3eade3d8e670a2230f51d6" +dependencies = [ + "libc", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -497,12 +523,12 @@ checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" [[package]] name = "libloading" -version = "0.8.1" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c571b676ddfc9a8c12f1f3d3085a7b163966a8fd8098a90640953ce5f6170161" +checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19" dependencies = [ "cfg-if", - "windows-sys 0.48.0", + "windows-targets", ] [[package]] @@ -519,9 +545,9 @@ checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" [[package]] name = "log" -version = "0.4.20" +version = "0.4.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" +checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" [[package]] name = "memchr" @@ -657,9 +683,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.78" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" +checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e" dependencies = [ "unicode-ident", ] @@ -745,9 +771,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5bb987efffd3c6d0d8f5f89510bb458559eab11e4f869acb20bf845e016259cd" +checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" dependencies = [ "aho-corasick", "memchr", @@ -774,15 +800,15 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "rustix" -version = "0.38.31" +version = "0.38.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ea3e1a662af26cd7a3ba09c0297a31af215563ecf42817c98df621387f4e949" +checksum = "65e04861e65f21776e67888bfbea442b3642beaa0138fdb1dd7a84a52dffdb89" dependencies = [ - "bitflags 2.4.2", + "bitflags 2.5.0", "errno", "libc", "linux-raw-sys", - "windows-sys 0.52.0", + 
"windows-sys", ] [[package]] @@ -812,7 +838,7 @@ version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5" dependencies = [ - "half", + "half 1.8.3", "serde", ] @@ -824,7 +850,7 @@ checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.53", ] [[package]] @@ -886,27 +912,27 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.13.1" +version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "strum" -version = "0.26.1" +version = "0.26.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "723b93e8addf9aa965ebe2d11da6d7540fa2283fcea14b3371ff055f7ba13f5f" +checksum = "5d8cec3501a5194c432b2b7976db6b7d10ec95c253208b45f83f7136aa985e29" [[package]] name = "strum_macros" -version = "0.26.1" +version = "0.26.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a3417fc93d76740d974a01654a09777cb500428cc874ca9f45edfe0c4d4cd18" +checksum = "c6cf59daf282c0a494ba14fd21610a0325f9f90ec9d1231dea26bcb1d696c946" dependencies = [ "heck", "proc-macro2", "quote", "rustversion", - "syn 2.0.50", + "syn 2.0.53", ] [[package]] @@ -928,9 +954,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.50" +version = "2.0.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74f1bdc9872430ce9b75da68329d1c1746faf50ffac5f19e02b71e37ff881ffb" +checksum = "7383cd0e49fff4b6b90ca5670bfd3e9d6a733b3f90c686605aa7eec8c4996032" dependencies = [ "proc-macro2", "quote", @@ -939,22 +965,22 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.57" +version = "1.0.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e45bcbe8ed29775f228095caf2cd67af7a4ccf756ebff23a306bf3e8b47b24b" +checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.57" +version = "1.0.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81" +checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.53", ] [[package]] @@ -984,9 +1010,9 @@ checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "typetag" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c43148481c7b66502c48f35b8eef38b6ccdc7a9f04bd4cc294226d901ccc9bc7" +checksum = "661d18414ec032a49ece2d56eee03636e43c4e8d577047ab334c0ba892e29aaf" dependencies = [ "erased-serde", "inventory", @@ -997,13 +1023,13 @@ dependencies = [ [[package]] name = "typetag-impl" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "291db8a81af4840c10d636e047cac67664e343be44e24dfdbd1492df9a5d3390" +checksum = "ac73887f47b9312552aa90ef477927ff014d63d1920ca8037c6c1951eab64bb1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.53", ] [[package]] @@ -1060,145 +1086,80 @@ dependencies = [ "rustix", ] -[[package]] -name = "windows-sys" 
-version = "0.48.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" -dependencies = [ - "windows-targets 0.48.5", -] - [[package]] name = "windows-sys" version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.3", -] - -[[package]] -name = "windows-targets" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" -dependencies = [ - "windows_aarch64_gnullvm 0.48.5", - "windows_aarch64_msvc 0.48.5", - "windows_i686_gnu 0.48.5", - "windows_i686_msvc 0.48.5", - "windows_x86_64_gnu 0.48.5", - "windows_x86_64_gnullvm 0.48.5", - "windows_x86_64_msvc 0.48.5", + "windows-targets", ] [[package]] name = "windows-targets" -version = "0.52.3" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d380ba1dc7187569a8a9e91ed34b8ccfc33123bbacb8c0aed2d1ad7f3ef2dc5f" +checksum = "7dd37b7e5ab9018759f893a1952c9420d060016fc19a472b4bb20d1bdd694d1b" dependencies = [ - "windows_aarch64_gnullvm 0.52.3", - "windows_aarch64_msvc 0.52.3", - "windows_i686_gnu 0.52.3", - "windows_i686_msvc 0.52.3", - "windows_x86_64_gnu 0.52.3", - "windows_x86_64_gnullvm 0.52.3", - "windows_x86_64_msvc 0.52.3", + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", ] [[package]] name = "windows_aarch64_gnullvm" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.52.3" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68e5dcfb9413f53afd9c8f86e56a7b4d86d9a2fa26090ea2dc9e40fba56c6ec6" +checksum = "bcf46cf4c365c6f2d1cc93ce535f2c8b244591df96ceee75d8e83deb70a9cac9" [[package]] name = "windows_aarch64_msvc" -version = "0.48.5" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.52.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8dab469ebbc45798319e69eebf92308e541ce46760b49b18c6b3fe5e8965b30f" +checksum = "da9f259dd3bcf6990b55bffd094c4f7235817ba4ceebde8e6d11cd0c5633b675" [[package]] name = "windows_i686_gnu" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" - -[[package]] -name = "windows_i686_gnu" -version = "0.52.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a4e9b6a7cac734a8b4138a4e1044eac3404d8326b6c0f939276560687a033fb" - -[[package]] -name = "windows_i686_msvc" -version = "0.48.5" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" +checksum = "b474d8268f99e0995f25b9f095bc7434632601028cf86590aea5c8a5cb7801d3" [[package]] name = "windows_i686_msvc" -version = "0.52.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"28b0ec9c422ca95ff34a78755cfa6ad4a51371da2a5ace67500cf7ca5f232c58" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.48.5" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" +checksum = "1515e9a29e5bed743cb4415a9ecf5dfca648ce85ee42e15873c3cd8610ff8e02" [[package]] name = "windows_x86_64_gnu" -version = "0.52.3" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "704131571ba93e89d7cd43482277d6632589b18ecf4468f591fbae0a8b101614" +checksum = "5eee091590e89cc02ad514ffe3ead9eb6b660aedca2183455434b93546371a03" [[package]] name = "windows_x86_64_gnullvm" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.52.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42079295511643151e98d61c38c0acc444e52dd42ab456f7ccfd5152e8ecf21c" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.48.5" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" +checksum = "77ca79f2451b49fa9e2af39f0747fe999fcda4f5e241b2898624dca97a1f2177" [[package]] name = "windows_x86_64_msvc" -version = "0.52.3" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0770833d60a970638e989b3fa9fd2bb1aaadcf88963d1659fd7d9990196ed2d6" +checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" [[package]] name = "xla" version = "0.1.6" -source = "git+https://github.com/Ebanflo42/xla-rs?branch=dev#4293f038ee1b8466da4bb0e0859413d6ea8a1aca" +source = "git+https://github.com/Ebanflo42/xla-rs?branch=dev#b436bb3fed5bbf4e60d406e51070673f0c0de1f0" dependencies = [ "bindgen", "cc", + "half 2.4.0", "libc", "num-derive", "num-traits", diff --git a/src/core/graph/autodiff.rs b/src/core/graph/autodiff.rs index 9cd796c..ed9503e 100644 --- a/src/core/graph/autodiff.rs +++ b/src/core/graph/autodiff.rs @@ -91,6 +91,16 @@ impl Context { } } + Operation::OneHot(node) => { + if self.gradient_is_dependent(node, dependent_node) { + return Err(ContextError::NonDifferentiableOpError( + self.nodes[dependent_node].callsite.clone(), + )); + } else { + continue; + } + } + Operation::TypeCast(a, _) => { if self.gradient_is_dependent(node, dependent_node) { return Err(ContextError::NonDifferentiableOpError( diff --git a/src/core/graph/compile.rs b/src/core/graph/compile.rs index e6f6d51..a5ccbf2 100644 --- a/src/core/graph/compile.rs +++ b/src/core/graph/compile.rs @@ -107,8 +107,8 @@ impl Context { if covered_ops.contains(dependent_op) { continue; } - let node = &self.nodes[*dependent_op]; - match node.operation { + let this_node = &self.nodes[*dependent_op]; + match this_node.operation { Operation::Parameter(_) => { unreachable!("Parameters can't depend on other nodes") } @@ -327,7 +327,25 @@ impl Context { covered_ops.insert(*dependent_op); } } - Operation::ReduceMax { node, dim, keepdims } => { + Operation::OneHot(node) => { + if unda_xla_map.contains_key(&node) + && xla_op_slotmap.contains_key(unda_xla_map[&node]) + { + let n_classes = this_node.shape.sizes[1]; + let dtype = this_node.dtype; + let xla_op = xla_op_slotmap[unda_xla_map[&node]] + .one_hot(n_classes as i64, dtype)?; + let xla_id = xla_op_slotmap.insert(xla_op); + 
unda_xla_map.insert(*dependent_op, xla_id); + unda_op_queue.push_back(*dependent_op); + covered_ops.insert(*dependent_op); + } + } + Operation::ReduceMax { + node, + dim, + keepdims, + } => { if xla_op_slotmap.contains_key(unda_xla_map[&node]) { let xla_op = xla_op_slotmap[unda_xla_map[&node]].reduce_max(&[dim], keepdims)?; diff --git a/src/core/graph/consteval.rs b/src/core/graph/consteval.rs index bee16fa..01aea56 100644 --- a/src/core/graph/consteval.rs +++ b/src/core/graph/consteval.rs @@ -5,12 +5,9 @@ use xla::ElementType; use super::*; impl Context { - - fn collect_deps( - &self, - node: NodeIdentifier, - ) -> Vec { - self.dependent_nodes[&node].iter() + fn collect_deps(&self, node: NodeIdentifier) -> Vec { + self.dependent_nodes[&node] + .iter() .map(|node| node.clone()) .collect::>() } @@ -18,10 +15,10 @@ impl Context { fn replace_index( &mut self, to_remove: NodeIdentifier, - rep_with: NodeIdentifier - ) -> Result { + rep_with: NodeIdentifier, + ) -> Result { let mut changed = false; - + let deps = self.collect_deps(to_remove); for dep_node in deps { @@ -37,8 +34,7 @@ impl Context { self.nodes[dep_node].operation = Operation::Add(a, rep_with); changed = true; } - - }, + } Operation::Sub(a, b) => { if a == b { self.nodes[dep_node].operation = Operation::Sub(rep_with, rep_with); @@ -50,8 +46,7 @@ impl Context { self.nodes[dep_node].operation = Operation::Sub(a, rep_with); changed = true; } - - }, + } Operation::Mul(a, b) => { if a == b { self.nodes[dep_node].operation = Operation::Mul(rep_with, rep_with); @@ -63,8 +58,7 @@ impl Context { self.nodes[dep_node].operation = Operation::Mul(a, rep_with); changed = true; } - - }, + } Operation::GreaterThan(a, b) => { if a == b { self.nodes[dep_node].operation = Operation::GreaterThan(rep_with, rep_with); @@ -76,11 +70,12 @@ impl Context { self.nodes[dep_node].operation = Operation::GreaterThan(a, rep_with); changed = true; } - }, + } Operation::GreaterThanEq(a, b) => { if a == b { - self.nodes[dep_node].operation = Operation::GreaterThanEq(rep_with, rep_with); + self.nodes[dep_node].operation = + Operation::GreaterThanEq(rep_with, rep_with); changed = true; } else if a == to_remove { self.nodes[dep_node].operation = Operation::GreaterThanEq(rep_with, b); @@ -89,7 +84,7 @@ impl Context { self.nodes[dep_node].operation = Operation::GreaterThanEq(a, rep_with); changed = true; } - }, + } Operation::Equal(a, b) => { if a == b { self.nodes[dep_node].operation = Operation::Equal(rep_with, rep_with); @@ -101,7 +96,7 @@ impl Context { self.nodes[dep_node].operation = Operation::Equal(a, rep_with); changed = true; } - }, + } Operation::NotEqual(a, b) => { if a == b { self.nodes[dep_node].operation = Operation::NotEqual(rep_with, rep_with); @@ -113,7 +108,7 @@ impl Context { self.nodes[dep_node].operation = Operation::NotEqual(a, rep_with); changed = true; } - }, + } Operation::LessThan(a, b) => { if a == b { self.nodes[dep_node].operation = Operation::LessThan(rep_with, rep_with); @@ -125,7 +120,7 @@ impl Context { self.nodes[dep_node].operation = Operation::LessThan(a, rep_with); changed = true; } - }, + } Operation::LessThanEq(a, b) => { if a == b { @@ -138,61 +133,110 @@ impl Context { self.nodes[dep_node].operation = Operation::LessThanEq(a, rep_with); changed = true; } - }, - Operation::Constant(_) - | Operation::Parameter(_) => { + } + Operation::Constant(_) | Operation::Parameter(_) => { unreachable!("Constants or Parameters cannot depend on nodes"); - }, + } Operation::StopGradient(a) => { if a == to_remove { self.nodes[dep_node].operation = 
Operation::StopGradient(rep_with); changed = true; } - }, + } Operation::Neg(a) => { if a == to_remove { self.nodes[dep_node].operation = Operation::Neg(rep_with); changed = true; } - }, + } Operation::ZerosLike(a) => { if a == to_remove { self.nodes[dep_node].operation = Operation::ZerosLike(rep_with); changed = true; } - }, + } + Operation::OneHot(node) => { + if node == to_remove { + self.nodes[dep_node].operation = Operation::OneHot(rep_with); + changed = true; + } + } Operation::TypeCast(_, t) => { changed = true; self.nodes[dep_node].operation = Operation::TypeCast(rep_with, t) - }, - Operation::Select { pred, on_true, on_false } => { + } + Operation::Select { + pred, + on_true, + on_false, + } => { if pred == to_remove { if pred == on_true { - self.nodes[dep_node].operation = Operation::Select { pred: rep_with, on_true: rep_with, on_false } + self.nodes[dep_node].operation = Operation::Select { + pred: rep_with, + on_true: rep_with, + on_false, + } } else if pred == on_false { - self.nodes[dep_node].operation = Operation::Select { pred: rep_with, on_true, on_false: rep_with } - } else { - self.nodes[dep_node].operation = Operation::Select { pred: rep_with, on_true, on_false } + self.nodes[dep_node].operation = Operation::Select { + pred: rep_with, + on_true, + on_false: rep_with, + } + } else { + self.nodes[dep_node].operation = Operation::Select { + pred: rep_with, + on_true, + on_false, + } } changed = true; } else if on_true == to_remove { changed = true; - self.nodes[dep_node].operation = Operation::Select { pred, on_true: rep_with, on_false } + self.nodes[dep_node].operation = Operation::Select { + pred, + on_true: rep_with, + on_false, + } } else if on_false == to_remove { changed = true; - self.nodes[dep_node].operation = Operation::Select { pred, on_true, on_false: rep_with } + self.nodes[dep_node].operation = Operation::Select { + pred, + on_true, + on_false: rep_with, + } } - }, - Operation::ReduceMax { node, dim, keepdims } => { + } + Operation::ReduceMax { + node, + dim, + keepdims, + } => { if node == to_remove { changed = true; - self.nodes[dep_node].operation = Operation::ReduceMax { node: rep_with, dim, keepdims } + self.nodes[dep_node].operation = Operation::ReduceMax { + node: rep_with, + dim, + keepdims, + } } - }, - Operation::SliceInDim { node, start, stop, stride, dim } => { + } + Operation::SliceInDim { + node, + start, + stop, + stride, + dim, + } => { if node == to_remove { changed = true; - self.nodes[dep_node].operation = Operation::SliceInDim { node: rep_with, start, stop, stride, dim } + self.nodes[dep_node].operation = Operation::SliceInDim { + node: rep_with, + start, + stop, + stride, + dim, + } } } } @@ -208,7 +252,7 @@ impl Context { &mut self, input: NodeIdentifier, modification_limit: usize, - ) -> Result { + ) -> Result { if modification_limit == 0 { return Ok(true); } @@ -224,30 +268,28 @@ impl Context { continue; } match self.nodes[node_id].operation { - Operation::Add(a, b) - | Operation::Sub(a, b) => { - if self.nodes[a].is_zero()? { - self.replace_index(node_id, b)?; - modifications += 1; - changed = true; - - } else if self.nodes[b].is_zero()? { - self.replace_index(node_id, a)?; - modifications += 1; - changed = true; - } - //Enqueue the dependent nodes to check both of them for constant - //mul/adding + Operation::Add(a, b) | Operation::Sub(a, b) => { + if self.nodes[a].is_zero()? { + self.replace_index(node_id, b)?; + modifications += 1; + changed = true; + } else if self.nodes[b].is_zero()? 
{ + self.replace_index(node_id, a)?; + modifications += 1; + changed = true; + } + //Enqueue the dependent nodes to check both of them for constant + //mul/adding - //TODO: Once we create a new Node based on the constant propegation, - //use insert_with_key to 'replace existant node' - if self.nodes.get(a).unwrap().is_const().is_none() { - to_visit.push(a); - } - if self.nodes.get(b).unwrap().is_const().is_none() { - to_visit.push(b); - } - }, + //TODO: Once we create a new Node based on the constant propegation, + //use insert_with_key to 'replace existant node' + if self.nodes.get(a).unwrap().is_const().is_none() { + to_visit.push(a); + } + if self.nodes.get(b).unwrap().is_const().is_none() { + to_visit.push(b); + } + } Operation::Mul(a, b) => { if self.nodes[a].is_zero()? { self.replace_index(node_id, a)?; @@ -256,7 +298,8 @@ impl Context { } if let Some(literal) = self.nodes[a].is_const() { //Check for mul by 1 - let floating_literal: Vec = literal.convert(xla::PrimitiveType::F32)?.to_vec()?; + let floating_literal: Vec = + literal.convert(xla::PrimitiveType::F32)?.to_vec()?; let mut all_one = true; floating_literal.iter().for_each(|elem| { if *elem != 1f32 { @@ -270,14 +313,15 @@ impl Context { changed = true; } } - if self.nodes[b].is_zero()?{ + if self.nodes[b].is_zero()? { self.replace_index(node_id, b)?; modifications += 1; changed = true; } if let Some(literal) = self.nodes[b].is_const() { //Check for mul by 1 - let floating_literal: Vec = literal.convert(xla::PrimitiveType::F32)?.to_vec()?; + let floating_literal: Vec = + literal.convert(xla::PrimitiveType::F32)?.to_vec()?; let mut all_one = true; floating_literal.iter().for_each(|elem| { if *elem != 1f32 { @@ -293,43 +337,43 @@ impl Context { } if let None = self.nodes[a].is_const() { to_visit.push(a); - } + } if let None = self.nodes[b].is_const() { to_visit.push(b); } - - }, + } Operation::Neg(a) => { if let None = self.nodes[a].is_const() { to_visit.push(a); } } Operation::GreaterThan(a, b) - | Operation::GreaterThanEq(a, b) - | Operation::LessThan(a, b) - | Operation::LessThanEq(a, b) - | Operation::Equal(a, b) - | Operation::NotEqual(a, b) - => { - - if let None = self.nodes[a].is_const() { - to_visit.push(a); - } - - if let None = self.nodes[b].is_const() { - to_visit.push(b); - } + | Operation::GreaterThanEq(a, b) + | Operation::LessThan(a, b) + | Operation::LessThanEq(a, b) + | Operation::Equal(a, b) + | Operation::NotEqual(a, b) => { + if let None = self.nodes[a].is_const() { + to_visit.push(a); + } - }, + if let None = self.nodes[b].is_const() { + to_visit.push(b); + } + } Operation::StopGradient(a) - | Operation::TypeCast(a, _) - | Operation::ZerosLike(a) - => { + | Operation::TypeCast(a, _) + | Operation::ZerosLike(a) + | Operation::OneHot(a) => { if let None = self.nodes[a].is_const() { to_visit.push(a); } - }, - Operation::Select { pred, on_true, on_false } => { + } + Operation::Select { + pred, + on_true, + on_false, + } => { if let None = self.nodes[pred].is_const() { to_visit.push(pred) } @@ -339,19 +383,28 @@ impl Context { if let None = self.nodes[on_false].is_const() { to_visit.push(on_false) } - }, - Operation::SliceInDim { node, start, stop, stride, dim } => { + } + Operation::SliceInDim { + node, + start, + stop, + stride, + dim, + } => { if let None = self.nodes[node].is_const() { to_visit.push(node); } - }, - Operation::ReduceMax { node, dim, keepdims } => { + } + Operation::ReduceMax { + node, + dim, + keepdims, + } => { if let None = self.nodes[node].is_const() { to_visit.push(node); } - }, - 
Operation::Constant(_) - | Operation::Parameter(_) => {} + } + Operation::Constant(_) | Operation::Parameter(_) => {} } visitied.insert(node_id); } diff --git a/src/core/graph/context.rs b/src/core/graph/context.rs index 2aa8880..05cc111 100644 --- a/src/core/graph/context.rs +++ b/src/core/graph/context.rs @@ -52,6 +52,12 @@ pub enum ContextError { #[error("Type is not differentiable, differentiable types are F16, Bf16, F32, F64, C64, C128")] NonDifferentiableTypeError(Callsite), + + #[error("Expected integral type")] + IntegralTypeError(Callsite), + + #[error("Expected tensor of rank {0}, got {1}")] + RankError(usize, usize, Callsite), } pub type Result = std::result::Result; @@ -120,17 +126,24 @@ impl Context { dim, } => format!( "SliceInDim ({}) {} {} {} {}", - self.to_string(node), start, stop, stride, dim + self.to_string(node), + start, + stop, + stride, + dim ), Operation::ZerosLike(node) => format!("ZerosLike {}", self.to_string(node)), + Operation::OneHot(node) => format!( + "OneHot ({}) {} {}", + self.to_string(node), + input_node.shape.sizes[1], + input_node.dtype + ), Operation::ReduceMax { node, dim, keepdims, - } => format!( - "SliceInDim {} {} {}", - self.to_string(node), dim, keepdims - ), + } => format!("SliceInDim {} {} {}", self.to_string(node), dim, keepdims), } } } diff --git a/src/core/graph/math.rs b/src/core/graph/math.rs index eafd673..2ee9c37 100644 --- a/src/core/graph/math.rs +++ b/src/core/graph/math.rs @@ -541,4 +541,35 @@ impl Context { .push(node_id); node_id } + + pub fn one_hot(&mut self, sparse_label_vector: NodeIdentifier, n_classes: usize, dtype: xla::ElementType) -> Result { + if self.nodes[sparse_label_vector].shape.ndims() != 1 { + return Err(ContextError::RankError(1, self.nodes[sparse_label_vector].shape.ndims(), callsite!(1))) + } + let label_len = self.nodes[sparse_label_vector].shape.sizes[0]; + + let converted = match self.nodes[sparse_label_vector].dtype { + xla::ElementType::S64 => sparse_label_vector, + xla::ElementType::U8 + | xla::ElementType::S8 + | xla::ElementType::U16 + | xla::ElementType::S16 + | xla::ElementType::U32 + | xla::ElementType::S32 + | xla::ElementType::U64 => self.type_cast(sparse_label_vector, xla::ElementType::S64), + _ => return Err(ContextError::IntegralTypeError(callsite!(1))) + }; + + let node_id = self.nodes.insert(Node { + callsite: callsite!(1), + shape: Shape::from([label_len, n_classes as u16]), + operation: Operation::OneHot(converted), + dtype: dtype, + }); + self.dependent_nodes + .entry(converted) + .or_insert(Vec::new()) + .push(node_id); + Ok(node_id) + } } diff --git a/src/core/graph/operation.rs b/src/core/graph/operation.rs index 405c714..554e1d5 100644 --- a/src/core/graph/operation.rs +++ b/src/core/graph/operation.rs @@ -28,6 +28,8 @@ pub enum Operation { ZerosLike(NodeIdentifier), ReduceMax{ node: NodeIdentifier, dim: i64, keepdims: bool }, + + OneHot(NodeIdentifier), } impl Display for Operation { From f91b4b88c7652527f6f6f8b9bea2a5e9cad006e5 Mon Sep 17 00:00:00 2001 From: Eben Kadile Date: Fri, 22 Mar 2024 11:29:56 +0100 Subject: [PATCH 02/21] working on accuracy --- src/core/graph/autodiff.rs | 3 ++- src/core/graph/compile.rs | 14 ++++++++++++++ src/core/graph/consteval.rs | 24 ++++++++++++++++++++++++ src/core/graph/context.rs | 12 +++++++++++- src/core/graph/math.rs | 34 ++++++++++++++++++++++++++++++++++ src/core/graph/operation.rs | 1 + 6 files changed, 86 insertions(+), 2 deletions(-) diff --git a/src/core/graph/autodiff.rs b/src/core/graph/autodiff.rs index ed9503e..6d443c8 100644 --- 
a/src/core/graph/autodiff.rs +++ b/src/core/graph/autodiff.rs @@ -91,7 +91,8 @@ impl Context { } } - Operation::OneHot(node) => { + Operation::OneHot(node) + | Operation::ReduceArgmax { node: node, dim: _, keepdims: _ } => { if self.gradient_is_dependent(node, dependent_node) { return Err(ContextError::NonDifferentiableOpError( self.nodes[dependent_node].callsite.clone(), diff --git a/src/core/graph/compile.rs b/src/core/graph/compile.rs index a5ccbf2..989b67e 100644 --- a/src/core/graph/compile.rs +++ b/src/core/graph/compile.rs @@ -355,6 +355,20 @@ impl Context { covered_ops.insert(*dependent_op); } } + Operation::ReduceArgmax { + node, + dim, + keepdims, + } => { + if xla_op_slotmap.contains_key(unda_xla_map[&node]) { + let xla_op = + xla_op_slotmap[unda_xla_map[&node]].reduce_argmax(dim, keepdims)?; + let xla_id = xla_op_slotmap.insert(xla_op); + unda_xla_map.insert(*dependent_op, xla_id); + unda_op_queue.push_back(*dependent_op); + covered_ops.insert(*dependent_op); + } + } } } } diff --git a/src/core/graph/consteval.rs b/src/core/graph/consteval.rs index 01aea56..5e2014f 100644 --- a/src/core/graph/consteval.rs +++ b/src/core/graph/consteval.rs @@ -221,6 +221,20 @@ impl Context { } } } + Operation::ReduceArgmax { + node, + dim, + keepdims, + } => { + if node == to_remove { + changed = true; + self.nodes[dep_node].operation = Operation::ReduceArgmax { + node: rep_with, + dim, + keepdims, + } + } + } Operation::SliceInDim { node, start, @@ -405,6 +419,16 @@ impl Context { } } Operation::Constant(_) | Operation::Parameter(_) => {} + Operation::ReduceArgmax { + node, + dim, + keepdims, + } => { + if let None = self.nodes[node].is_const() { + to_visit.push(node); + } + } + Operation::Constant(_) | Operation::Parameter(_) => {} } visitied.insert(node_id); } diff --git a/src/core/graph/context.rs b/src/core/graph/context.rs index 05cc111..7c6d7ab 100644 --- a/src/core/graph/context.rs +++ b/src/core/graph/context.rs @@ -143,7 +143,17 @@ impl Context { node, dim, keepdims, - } => format!("SliceInDim {} {} {}", self.to_string(node), dim, keepdims), + } => format!("ReduceMax {} {} {}", self.to_string(node), dim, keepdims), + Operation::ReduceMax { + node, + dim, + keepdims, + } => format!("ReduceMax {} {} {}", self.to_string(node), dim, keepdims), + Operation::ReduceArgmax { + node, + dim, + keepdims, + } => format!("ReduceArgmax {} {} {}", self.to_string(node), dim, keepdims), } } } diff --git a/src/core/graph/math.rs b/src/core/graph/math.rs index 2ee9c37..bcff17e 100644 --- a/src/core/graph/math.rs +++ b/src/core/graph/math.rs @@ -542,6 +542,40 @@ impl Context { node_id } + pub fn reduce_argmax(&mut self, a: NodeIdentifier, dim: i64, keepdims: bool) -> NodeIdentifier { + let a = a.into(); + let mut s = Shape::new(); + for d in (0..self.nodes[a].shape.ndims()).rev() { + if d as i64 == dim && keepdims { + s.sizes.push(1) + } else { + s.sizes.push(self.nodes[a].shape.sizes[d]) + } + } + let node_id = self.nodes.insert(Node { + callsite: callsite!(1), + shape: s, + operation: Operation::ReduceArgmax { + node: a, + dim: dim, + keepdims: keepdims, + }, + dtype: xla::ElementType::S64, + }); + self.dependent_nodes + .entry(a) + .or_insert(Vec::new()) + .push(node_id); + node_id + } + + pub fn accuracy(&mut self, dense_predictions: NodeIdentifier, sparse_label_vector: NodeIdentifier) -> Result { + let sparse_predictions = self.reduce_argmax(dense_predictions, 1, false); + let compare = self.eq(sparse_predictions, sparse_label_vector)?; + let converted = self.type_cast(comparse, 
xla::ElementType::F32); + self.reduce_mean(converted, 0, false); + } + pub fn one_hot(&mut self, sparse_label_vector: NodeIdentifier, n_classes: usize, dtype: xla::ElementType) -> Result { if self.nodes[sparse_label_vector].shape.ndims() != 1 { return Err(ContextError::RankError(1, self.nodes[sparse_label_vector].shape.ndims(), callsite!(1))) diff --git a/src/core/graph/operation.rs b/src/core/graph/operation.rs index 554e1d5..9f7f3bb 100644 --- a/src/core/graph/operation.rs +++ b/src/core/graph/operation.rs @@ -28,6 +28,7 @@ pub enum Operation { ZerosLike(NodeIdentifier), ReduceMax{ node: NodeIdentifier, dim: i64, keepdims: bool }, + ReduceArgmax{ node: NodeIdentifier, dim: i64, keepdims: bool }, OneHot(NodeIdentifier), } From e1fc10d491855037c9a547568e1d2b6a863f57e2 Mon Sep 17 00:00:00 2001 From: Eben Kadile Date: Fri, 22 Mar 2024 11:53:02 +0100 Subject: [PATCH 03/21] building again --- Cargo.lock | 30 ----- src/core/graph/compile.rs | 9 +- src/core/graph/consteval.rs | 6 +- src/core/graph/context.rs | 4 + src/core/graph/math.rs | 250 +++++++++++------------------------- 5 files changed, 87 insertions(+), 212 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2c2ab0b..3652e79 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -358,11 +358,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", -<<<<<<< HEAD - "syn 2.0.53", -======= "syn 2.0.52", ->>>>>>> master ] [[package]] @@ -854,11 +850,7 @@ checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", -<<<<<<< HEAD - "syn 2.0.53", -======= "syn 2.0.52", ->>>>>>> master ] [[package]] @@ -940,11 +932,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", -<<<<<<< HEAD - "syn 2.0.53", -======= "syn 2.0.52", ->>>>>>> master ] [[package]] @@ -966,15 +954,9 @@ dependencies = [ [[package]] name = "syn" -<<<<<<< HEAD -version = "2.0.53" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7383cd0e49fff4b6b90ca5670bfd3e9d6a733b3f90c686605aa7eec8c4996032" -======= version = "2.0.52" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b699d15b36d1f02c3e7c69f8ffef53de37aefae075d8488d4ba1a7788d574a07" ->>>>>>> master dependencies = [ "proc-macro2", "quote", @@ -998,11 +980,7 @@ checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" dependencies = [ "proc-macro2", "quote", -<<<<<<< HEAD - "syn 2.0.53", -======= "syn 2.0.52", ->>>>>>> master ] [[package]] @@ -1051,11 +1029,7 @@ checksum = "ac73887f47b9312552aa90ef477927ff014d63d1920ca8037c6c1951eab64bb1" dependencies = [ "proc-macro2", "quote", -<<<<<<< HEAD - "syn 2.0.53", -======= "syn 2.0.52", ->>>>>>> master ] [[package]] @@ -1182,11 +1156,7 @@ checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" [[package]] name = "xla" version = "0.1.6" -<<<<<<< HEAD source = "git+https://github.com/Ebanflo42/xla-rs?branch=dev#b436bb3fed5bbf4e60d406e51070673f0c0de1f0" -======= -source = "git+https://github.com/Ebanflo42/xla-rs?branch=dev#e90f9250df62aea0ec7efe15c8040fa5f7258f6b" ->>>>>>> master dependencies = [ "bindgen", "cc", diff --git a/src/core/graph/compile.rs b/src/core/graph/compile.rs index c97a143..7194e8e 100644 --- a/src/core/graph/compile.rs +++ b/src/core/graph/compile.rs @@ -108,11 +108,11 @@ impl Context { if covered_ops.contains(dependent_op) { continue; } - let node = &self.nodes[*dependent_op]; + let this_node = &self.nodes[*dependent_op]; //TODO: Clone here is not 
great, we could & the node operation //or come up with a better way of storing the Vec that Transpose //uses(that's what causes the borrow checker error if we dont clone) - match node.operation.clone() { + match this_node.operation.clone() { Operation::Parameter(_) => { unreachable!("Parameters can't depend on other nodes") } @@ -390,7 +390,7 @@ impl Context { && xla_op_slotmap.contains_key(unda_xla_map[&node]) { let xla_op = xla_op_slotmap[unda_xla_map[&node]].reshape( - self.nodes[*dependent_op] + this_node .shape .sizes .iter() @@ -505,11 +505,10 @@ impl Context { Operation::ReduceArgmax { node, dim, - keepdims, } => { if xla_op_slotmap.contains_key(unda_xla_map[&node]) { let xla_op = - xla_op_slotmap[unda_xla_map[&node]].reduce_argmax(dim, keepdims)?; + xla_op_slotmap[unda_xla_map[&node]].reduce_argmax(dim, false)?; let xla_id = xla_op_slotmap.insert(xla_op); unda_xla_map.insert(*dependent_op, xla_id); unda_op_queue.push_back(*dependent_op); diff --git a/src/core/graph/consteval.rs b/src/core/graph/consteval.rs index 8649be5..0b321a2 100644 --- a/src/core/graph/consteval.rs +++ b/src/core/graph/consteval.rs @@ -492,7 +492,8 @@ impl Context { Operation::StopGradient(a) | Operation::TypeCast(a, _) | Operation::Reshape(a) - | Operation::ZerosLike(a) => { + | Operation::ZerosLike(a) + | Operation::OneHot(a) => { if let None = self.nodes[a].is_const() { to_visit.push(a); } @@ -530,7 +531,8 @@ impl Context { } Operation::ReduceMax { node, dim: _ } | Operation::ReduceSum { node, dim: _ } - | Operation::ReduceMean { node, dim:_ } => { + | Operation::ReduceMean { node, dim:_ } + | Operation:: ReduceArgmax { node, dim: _ } => { if let None = self.nodes[node].is_const() { to_visit.push(node); } diff --git a/src/core/graph/context.rs b/src/core/graph/context.rs index 733d4f9..c79df8a 100644 --- a/src/core/graph/context.rs +++ b/src/core/graph/context.rs @@ -158,6 +158,10 @@ impl Context { node, dim, } => format!("ReduceMax {} {}", self.to_string(node), dim), + Operation::ReduceArgmax { + node, + dim, + } => format!("ReduceArgmax {} {}", self.to_string(node), dim), Operation::ReduceSum { node, dim, diff --git a/src/core/graph/math.rs b/src/core/graph/math.rs index e871942..530c61b 100644 --- a/src/core/graph/math.rs +++ b/src/core/graph/math.rs @@ -30,15 +30,9 @@ impl Context { dtype: node_a.dtype, }; let node_id = self.nodes.insert(node); - self.dependent_nodes - .entry(a) - .or_default() - .push(node_id); + self.dependent_nodes.entry(a).or_default().push(node_id); if a != b { - self.dependent_nodes - .entry(b) - .or_default() - .push(node_id); + self.dependent_nodes.entry(b).or_default().push(node_id); } Ok(node_id) } @@ -71,15 +65,9 @@ impl Context { dtype: node_a.dtype, }; let node_id = self.nodes.insert(node); - self.dependent_nodes - .entry(a) - .or_default() - .push(node_id); + self.dependent_nodes.entry(a).or_default().push(node_id); if a != b { - self.dependent_nodes - .entry(b) - .or_default() - .push(node_id); + self.dependent_nodes.entry(b).or_default().push(node_id); } Ok(node_id) } @@ -95,10 +83,7 @@ impl Context { dtype: self.nodes[a].dtype, }; let node_id = self.nodes.insert(node); - self.dependent_nodes - .entry(a) - .or_default() - .push(node_id); + self.dependent_nodes.entry(a).or_default().push(node_id); node_id } @@ -110,14 +95,10 @@ impl Context { dtype: self.nodes[a].dtype, }; let node_id = self.nodes.insert(node); - self.dependent_nodes - .entry(a) - .or_default() - .push(node_id); + self.dependent_nodes.entry(a).or_default().push(node_id); Ok(node_id) } - pub fn exp(&mut 
self, a: NodeIdentifier) -> Result { let node = Node { callsite: callsite!(1), @@ -126,14 +107,11 @@ impl Context { dtype: self.nodes[a].dtype, }; let node_id = self.nodes.insert(node); - self.dependent_nodes - .entry(a) - .or_default() - .push(node_id); + self.dependent_nodes.entry(a).or_default().push(node_id); Ok(node_id) } - pub fn pow(&mut self, a: NodeIdentifier, b : NodeIdentifier) -> Result { + pub fn pow(&mut self, a: NodeIdentifier, b: NodeIdentifier) -> Result { let node_a = &self.nodes[a]; let node_b = &self.nodes[b]; @@ -158,15 +136,9 @@ impl Context { dtype: node_a.dtype, }; let node_id = self.nodes.insert(node); - self.dependent_nodes - .entry(a) - .or_default() - .push(node_id); + self.dependent_nodes.entry(a).or_default().push(node_id); if a != b { - self.dependent_nodes - .entry(b) - .or_default() - .push(node_id); + self.dependent_nodes.entry(b).or_default().push(node_id); } Ok(node_id) } @@ -220,15 +192,9 @@ impl Context { dtype: node_a.dtype, }; let node_id = self.nodes.insert(node); - self.dependent_nodes - .entry(a) - .or_default() - .push(node_id); + self.dependent_nodes.entry(a).or_default().push(node_id); if a != b { - self.dependent_nodes - .entry(b) - .or_default() - .push(node_id); + self.dependent_nodes.entry(b).or_default().push(node_id); } Ok(node_id) } @@ -236,7 +202,6 @@ impl Context { } } - pub fn mul(&mut self, a: NodeIdentifier, b: NodeIdentifier) -> Result { let node_a = &self.nodes[a]; let node_b = &self.nodes[b]; @@ -262,15 +227,9 @@ impl Context { dtype: node_a.dtype, }; let node_id = self.nodes.insert(node); - self.dependent_nodes - .entry(a) - .or_default() - .push(node_id); + self.dependent_nodes.entry(a).or_default().push(node_id); if a != b { - self.dependent_nodes - .entry(b) - .or_default() - .push(node_id); + self.dependent_nodes.entry(b).or_default().push(node_id); } Ok(node_id) } @@ -303,15 +262,9 @@ impl Context { dtype: node_a.dtype, }; let node_id = self.nodes.insert(node); - self.dependent_nodes - .entry(a) - .or_default() - .push(node_id); + self.dependent_nodes.entry(a).or_default().push(node_id); if a != b { - self.dependent_nodes - .entry(b) - .or_default() - .push(node_id); + self.dependent_nodes.entry(b).or_default().push(node_id); } Ok(node_id) } @@ -344,14 +297,8 @@ impl Context { dtype: xla::ElementType::Pred, }; let node_id = self.nodes.insert(node); - self.dependent_nodes - .entry(a) - .or_default() - .push(node_id); - self.dependent_nodes - .entry(b) - .or_default() - .push(node_id); + self.dependent_nodes.entry(a).or_default().push(node_id); + self.dependent_nodes.entry(b).or_default().push(node_id); Ok(node_id) } } @@ -383,15 +330,9 @@ impl Context { dtype: xla::ElementType::Pred, }; let node_id = self.nodes.insert(node); - self.dependent_nodes - .entry(a) - .or_default() - .push(node_id); + self.dependent_nodes.entry(a).or_default().push(node_id); if a != b { - self.dependent_nodes - .entry(b) - .or_default() - .push(node_id); + self.dependent_nodes.entry(b).or_default().push(node_id); } Ok(node_id) } @@ -424,15 +365,9 @@ impl Context { dtype: xla::ElementType::Pred, }; let node_id = self.nodes.insert(node); - self.dependent_nodes - .entry(a) - .or_default() - .push(node_id); + self.dependent_nodes.entry(a).or_default().push(node_id); if a != b { - self.dependent_nodes - .entry(b) - .or_default() - .push(node_id); + self.dependent_nodes.entry(b).or_default().push(node_id); } Ok(node_id) } @@ -465,15 +400,9 @@ impl Context { dtype: xla::ElementType::Pred, }; let node_id = self.nodes.insert(node); - self.dependent_nodes - 
.entry(a) - .or_default() - .push(node_id); + self.dependent_nodes.entry(a).or_default().push(node_id); if a != b { - self.dependent_nodes - .entry(b) - .or_default() - .push(node_id); + self.dependent_nodes.entry(b).or_default().push(node_id); } Ok(node_id) } @@ -506,15 +435,9 @@ impl Context { dtype: xla::ElementType::Pred, }; let node_id = self.nodes.insert(node); - self.dependent_nodes - .entry(a) - .or_default() - .push(node_id); + self.dependent_nodes.entry(a).or_default().push(node_id); if a != b { - self.dependent_nodes - .entry(b) - .or_default() - .push(node_id); + self.dependent_nodes.entry(b).or_default().push(node_id); } Ok(node_id) } @@ -547,15 +470,9 @@ impl Context { dtype: xla::ElementType::Pred, }; let node_id = self.nodes.insert(node); - self.dependent_nodes - .entry(a) - .or_default() - .push(node_id); + self.dependent_nodes.entry(a).or_default().push(node_id); if a != b { - self.dependent_nodes - .entry(b) - .or_default() - .push(node_id); + self.dependent_nodes.entry(b).or_default().push(node_id); } Ok(node_id) } @@ -630,10 +547,7 @@ impl Context { operation: Operation::TypeCast(a, dtype), dtype, }); - self.dependent_nodes - .entry(a) - .or_default() - .push(node_id); + self.dependent_nodes.entry(a).or_default().push(node_id); node_id } @@ -654,10 +568,7 @@ impl Context { operation: Operation::Reshape(a), dtype: self.nodes[a].dtype, }); - self.dependent_nodes - .entry(a) - .or_default() - .push(node_id); + self.dependent_nodes.entry(a).or_default().push(node_id); Ok(node_id) } } @@ -681,10 +592,7 @@ impl Context { operation: Operation::Transpose(a, index_perms_deref), dtype: self.nodes[a].dtype, }); - self.dependent_nodes - .entry(a) - .or_default() - .push(node_id); + self.dependent_nodes.entry(a).or_default().push(node_id); Ok(node_id) } @@ -717,10 +625,7 @@ impl Context { }, dtype: self.nodes[a].dtype, }); - self.dependent_nodes - .entry(a) - .or_default() - .push(node_id); + self.dependent_nodes.entry(a).or_default().push(node_id); Ok(node_id) } @@ -749,10 +654,7 @@ impl Context { }, dtype: self.nodes[a].dtype, }); - self.dependent_nodes - .entry(a) - .or_default() - .push(node_id); + self.dependent_nodes.entry(a).or_default().push(node_id); Ok(node_id) } @@ -763,14 +665,16 @@ impl Context { operation: Operation::ZerosLike(a), dtype: self.nodes[a].dtype, }); - self.dependent_nodes - .entry(a) - .or_default() - .push(node_id); + self.dependent_nodes.entry(a).or_default().push(node_id); node_id } - fn maybe_keepdims(&mut self, a: NodeIdentifier, dim: i64, keepdims: bool) -> Result { + fn maybe_keepdims( + &mut self, + a: NodeIdentifier, + dim: i64, + keepdims: bool, + ) -> Result { if keepdims { let mut s_keepdim = self.nodes[a].shape.clone(); s_keepdim.sizes.insert(dim as usize, 1u32); @@ -795,16 +699,10 @@ impl Context { let node_id = self.nodes.insert(Node { callsite: callsite!(1), shape: s, - operation: Operation::ReduceMax { - node: a, - dim, - }, + operation: Operation::ReduceMax { node: a, dim }, dtype: self.nodes[a].dtype, }); - self.dependent_nodes - .entry(a) - .or_default() - .push(node_id); + self.dependent_nodes.entry(a).or_default().push(node_id); self.maybe_keepdims(node_id, dim, keepdims) } @@ -823,16 +721,10 @@ impl Context { let node_id = self.nodes.insert(Node { callsite: callsite!(1), shape: s, - operation: Operation::ReduceSum { - node: a, - dim, - }, + operation: Operation::ReduceSum { node: a, dim }, dtype: self.nodes[a].dtype, }); - self.dependent_nodes - .entry(a) - .or_default() - .push(node_id); + 
self.dependent_nodes.entry(a).or_default().push(node_id); self.maybe_keepdims(node_id, dim, keepdims) } @@ -851,20 +743,19 @@ impl Context { let node_id = self.nodes.insert(Node { callsite: callsite!(1), shape: s, - operation: Operation::ReduceMean { - node: a, - dim, - }, + operation: Operation::ReduceMean { node: a, dim }, dtype: self.nodes[a].dtype, }); - self.dependent_nodes - .entry(a) - .or_default() - .push(node_id); + self.dependent_nodes.entry(a).or_default().push(node_id); self.maybe_keepdims(node_id, dim, keepdims) } - pub fn reduce_argmax(&mut self, a: NodeIdentifier, dim: i64, keepdims: bool) -> NodeIdentifier { + pub fn reduce_argmax( + &mut self, + a: NodeIdentifier, + dim: i64, + keepdims: bool, + ) -> Result { let a = a.into(); let mut s = Shape::new(); for d in (0..self.nodes[a].shape.ndims()).rev() { @@ -877,30 +768,39 @@ impl Context { let node_id = self.nodes.insert(Node { callsite: callsite!(1), shape: s, - operation: Operation::ReduceArgmax { - node: a, - dim: dim, - keepdims: keepdims, - }, + operation: Operation::ReduceArgmax { node: a, dim: dim }, dtype: xla::ElementType::S64, }); self.dependent_nodes .entry(a) .or_insert(Vec::new()) .push(node_id); - node_id + self.maybe_keepdims(node_id, dim, keepdims) } - pub fn accuracy(&mut self, dense_predictions: NodeIdentifier, sparse_label_vector: NodeIdentifier) -> Result { - let sparse_predictions = self.reduce_argmax(dense_predictions, 1, false); + pub fn accuracy( + &mut self, + dense_predictions: NodeIdentifier, + sparse_label_vector: NodeIdentifier, + ) -> Result { + let sparse_predictions = self.reduce_argmax(dense_predictions, 1, false)?; let compare = self.eq(sparse_predictions, sparse_label_vector)?; - let converted = self.type_cast(comparse, xla::ElementType::F32); - self.reduce_mean(converted, 0, false); + let converted = self.type_cast(compare, xla::ElementType::F32); + self.reduce_mean(converted, 0, false) } - pub fn one_hot(&mut self, sparse_label_vector: NodeIdentifier, n_classes: usize, dtype: xla::ElementType) -> Result { + pub fn one_hot( + &mut self, + sparse_label_vector: NodeIdentifier, + n_classes: usize, + dtype: xla::ElementType, + ) -> Result { if self.nodes[sparse_label_vector].shape.ndims() != 1 { - return Err(ContextError::RankError(1, self.nodes[sparse_label_vector].shape.ndims(), callsite!(1))) + return Err(ContextError::RankError( + 1, + self.nodes[sparse_label_vector].shape.ndims(), + callsite!(1), + )); } let label_len = self.nodes[sparse_label_vector].shape.sizes[0]; @@ -913,12 +813,12 @@ impl Context { | xla::ElementType::U32 | xla::ElementType::S32 | xla::ElementType::U64 => self.type_cast(sparse_label_vector, xla::ElementType::S64), - _ => return Err(ContextError::IntegralTypeError(callsite!(1))) + _ => return Err(ContextError::IntegralTypeError(callsite!(1))), }; let node_id = self.nodes.insert(Node { callsite: callsite!(1), - shape: Shape::from([label_len, n_classes as u16]), + shape: Shape::from([label_len, n_classes as u32]), operation: Operation::OneHot(converted), dtype: dtype, }); From 06626f66684c525f499aef83b72c93fec07acd25 Mon Sep 17 00:00:00 2001 From: Eben Kadile Date: Fri, 22 Mar 2024 15:42:37 +0100 Subject: [PATCH 04/21] started working on mnist example --- Cargo.lock | 15 ++-- Cargo.toml | 1 + examples/mnist.rs | 31 +++++-- examples/mnist_xla.rs | 166 +++++++++++++++++++++++++++++++++++++ src/core/graph/autodiff.rs | 2 +- src/core/graph/context.rs | 11 ++- src/core/graph/math.rs | 41 ++++++++- src/core/graph/node.rs | 10 +-- 8 files changed, 253 insertions(+), 24 
deletions(-) create mode 100644 examples/mnist_xla.rs diff --git a/Cargo.lock b/Cargo.lock index 3652e79..640c20b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -358,7 +358,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -850,7 +850,7 @@ checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -932,7 +932,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -954,9 +954,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.52" +version = "2.0.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b699d15b36d1f02c3e7c69f8ffef53de37aefae075d8488d4ba1a7788d574a07" +checksum = "7383cd0e49fff4b6b90ca5670bfd3e9d6a733b3f90c686605aa7eec8c4996032" dependencies = [ "proc-macro2", "quote", @@ -980,7 +980,7 @@ checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -1029,7 +1029,7 @@ checksum = "ac73887f47b9312552aa90ef477927ff014d63d1920ca8037c6c1951eab64bb1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -1037,6 +1037,7 @@ name = "unda" version = "0.2.2" dependencies = [ "backtrace", + "byteorder", "csv", "futures", "half 2.4.0", diff --git a/Cargo.toml b/Cargo.toml index 9ce3623..e7f95dc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,6 +35,7 @@ strum_macros = "0.26" xla = { git = "https://github.com/Ebanflo42/xla-rs", version = "0.1.6" , branch = "dev" } thiserror = "1" half = "2.4.0" +byteorder = "1.5" [features] default = ["util"] diff --git a/examples/mnist.rs b/examples/mnist.rs index d5307e9..16f6a04 100644 --- a/examples/mnist.rs +++ b/examples/mnist.rs @@ -1,8 +1,18 @@ -use unda::{core::{data::{input::Input, matrix::Matrix}, network::Sequential, layer::{layers::{LayerTypes, InputTypes}, methods::{activations::Activations, errors::ErrorTypes}}}, util::{mnist::MnistEntry, categorical::to_categorical}}; +use unda::{ + core::{ + data::{input::Input, matrix::Matrix}, + layer::{ + layers::{InputTypes, LayerTypes}, + methods::{activations::Activations, errors::ErrorTypes}, + }, + network::Sequential, + }, + util::{categorical::to_categorical, mnist::MnistEntry}, +}; fn main() { let mut inputs: Vec<&dyn Input> = vec![]; - + let mut true_outputs: Vec> = vec![]; let inputs_undyn: Vec; @@ -13,7 +23,7 @@ fn main() { println!("Done Generating MNIST"); let outputs: Vec> = to_categorical(outputs_uncat); - for i in 0..600{ + for i in 0..600 { inputs.push(&inputs_undyn[i]); true_outputs.push(outputs[i].clone()); } @@ -28,8 +38,17 @@ fn main() { network.compile(); - network.fit(&inputs, &true_outputs, 1, ErrorTypes::CategoricalCrossEntropy); - for i in 0..10{ - println!("predicted: {:?} \n\nactual: {:?}\n\n\n", network.predict(inputs[i]), true_outputs[i]); + network.fit( + &inputs, + &true_outputs, + 1, + ErrorTypes::CategoricalCrossEntropy, + ); + for i in 0..10 { + println!( + "predicted: {:?} \n\nactual: {:?}\n\n\n", + network.predict(inputs[i]), + true_outputs[i] + ); } } diff --git a/examples/mnist_xla.rs b/examples/mnist_xla.rs new file mode 100644 index 0000000..e3373b9 --- /dev/null +++ b/examples/mnist_xla.rs @@ -0,0 +1,166 @@ +use unda::core::graph::*; +use xla::{ElementType::*, PjRtLoadedExecutable}; +use std::io; +use 
std::fs::File; +use std::os::unix::fs::*; +use byteorder::{LittleEndian, ReadBytesExt}; + +// ABSTRACT API REQUIREMENT 1: Automatic Layer Construction +// We should have functions like this which, for a given layer type, +// automatically resolve shapes and dtypes and construct nodes for +// the parameters and outputs of a layer. In the final version, +// a function like this should also take an "initialization" parameter +// and run random initialization for the weights and bias. +fn dense( + model: &mut Context, + input_node: NodeIdentifier, + out_size: u32, + name: &str, +) -> Result<(NodeIdentifier, (NodeIdentifier, NodeIdentifier))> { + let shape = model.nodes[input_node].shape; + let last_dim = shape.sizes[shape.ndims() - 1]; + let dtype = model.nodes[input_node].dtype; + + let weights_shape = Shape::from([out_size, last_dim]); + let mut weights_name = name.to_owned(); + weights_name.push_str("_weights"); + let weights = model.parameter(weights_name, weights_shape, dtype)?; + + let mut bias_shape = Shape::new(); + bias_shape.sizes.push(out_size); + for i in 0..(shape.ndims() - 1) { + bias_shape.sizes.push(1u32); + } + let mut bias_name = name.to_owned(); + bias_name.push_str("_bias"); + let bias = model.parameter(bias_name, bias_shape, dtype)?; + + let matmul_node = model.matmul(weights, input_node)?; + let dense_node = model.add(matmul_node, bias)?; + + Ok((dense_node, (weights, bias))) +} + +fn build_model_and_optimizer(client: &xla::PjRtClient) -> Result { + let mut model = Context::new(); + + // ABSTRACT API REQUIREMENT 2: Dynamic Batching + // In this example, the batch size is hardcoded to 100. + // This is fine because MNIST has exactly 60K training + // and 10K testing examples. It should not be generally + // assumed that the batch size divides the dataset size. + // Abstract model objects must be optimized for a specific + // batch size but be willing to take any. One simple way to + // achieve this would be simply having constant batch size + // but masking (via multiplication with a binary vector) + // the loss on "empty" examples. + let image_input = model.parameter("image_input", [100, 28 * 28], U8)?; + let image_fp = model.type_cast(image_input, F32); + // MNIST bytes range from 0 to 255, neural network only wants to see 0 to 1 + let scale = model.scalar(1f32 / 255f32, F32)?; + let image_rescaled = model.div(image_fp, scale)?; + + let sparse_labels = model.parameter("sparse_labels", [100], S64)?; + let one_hot_labels = model.one_hot(sparse_labels, 10, F32)?; + + let (d1, (w1, b1)) = dense(&mut model, image_rescaled, 784, "layer1")?; + let d1_activation = model.relu(d1)?; + let (d2, (w2, b2)) = dense(&mut model, d1_activation, 256, "layer2")?; + let d2_activation = model.relu(d2)?; + let (d3, (w3, b3)) = dense(&mut model, d2_activation, 64, "layer3")?; + let d3_activation = model.relu(d3)?; + let (logits, (w_out, b_out)) = dense(&mut model, d3_activation, 10, "out_layer")?; + let probabilities = model.softmax(logits)?; + let loss = model.mean_cross_entropy(probabilities, one_hot_labels)?; + let accuracy = model.accuracy(probabilities, sparse_labels)?; + + // ABSTRACT API REQUIREMENT 3: Separate forward/backward pass + // In this construction, the context contains both the forward + // prediction computations and the backward update computations. + // There should be a method for extracting ONLY the forward pass, + // as during inference we do not want to perform the backward computations. 
+ // Part of this issue should be the implementation of optional + // gradient clipping on the backward pass. + let (w1_grad, b1_grad) = (model.diff(loss, w1)?, model.diff(loss, b1)?); + let (w2_grad, b2_grad) = (model.diff(loss, w2)?, model.diff(loss, b2)?); + let (w3_grad, b3_grad) = (model.diff(loss, w3)?, model.diff(loss, b3)?); + let (w_out_grad, b_out_grad) = (model.diff(loss, w_out)?, model.diff(loss, b_out)?); + + // ABSTRACT API REQUIREMENT 4: Separate model/optimizer step + // In general, models and optimizers are thought of as separate + // objects, so should be separate in principle. Additionally, + // with large-scale models we want to be able to compute the + // gradients and then IN PARALLEL 1) compute parameter updates + // and 2) bus the next model input to the device. + // This will require binding XLA operations Send, Recv, + // and potentially OptimizationBarrier. + // Part of this issue should be the implementation of conventional + // optimizers (SGD, RMSProp, Adam), and learning rate schedules + // (ExponentialDecay, ReduceLROnPlateau, CosineAnnealing) + let learning_rate = model.parameter("learning_rate", [], F32)?; + // simple SGD updates + let (w1_update, b1_update) = ( + model.mul(learning_rate, w1_grad)?, + model.mul(learning_rate, b1_grad)?, + ); + let (w2_update, b2_update) = ( + model.mul(learning_rate, w2_grad)?, + model.mul(learning_rate, b2_grad)?, + ); + let (w3_update, b3_update) = ( + model.mul(learning_rate, w3_grad)?, + model.mul(learning_rate, b3_grad)?, + ); + let (w_out_update, b_out_update) = ( + model.mul(learning_rate, w_out_grad)?, + model.mul(learning_rate, b_out_grad)?, + ); + // apply updates + let (w1_new, b1_new ) = ( + model.sub(w1, w1_update)?, + model.sub(b1, b1_update)?, + ); + let (w2_new, b2_new ) = ( + model.sub(w2, w2_update)?, + model.sub(b2, b2_update)?, + ); + let (w3_new, b3_new ) = ( + model.sub(w3, w3_update)?, + model.sub(b3, b3_update)?, + ); + let (w_out_new, b_out_new ) = ( + model.sub(w_out, w_out_update)?, + model.sub(b_out, b_out_update)?, + ); + + model.compile("train_step", + [loss, accuracy, w1_new, b1_new, w2_new, b2_new, w3_new, b3_new, w_out_new, b_out_new], + client) +} + +// ABSTRACT API REQUIREMENT 5: Data prefetching +// Data input to the training loop should be in the form of an iterator. +// These iterators could be finite or infinite. +// In the finite case, we should support random shuffling. +// In either case, batches of data should be pre-fetched in parallel +// (and potentially preprocessed by the CPU) as the training loop is executing. 
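// A minimal sketch of such a prefetch queue, for illustration only: the
// `prefetch` helper, its `depth` parameter, and the choice of a bounded
// std::sync::mpsc channel are assumptions here, not anything this patch
// introduces, and it presumes the batch type (e.g. a pair of xla::Literal)
// is Send.

use std::sync::mpsc::{sync_channel, Receiver};
use std::thread;

fn prefetch<T, I>(batches: I, depth: usize) -> Receiver<T>
where
    T: Send + 'static,
    I: Iterator<Item = T> + Send + 'static,
{
    let (tx, rx) = sync_channel(depth);
    thread::spawn(move || {
        for batch in batches {
            // send blocks once `depth` batches are queued, so loading stays
            // at most `depth` steps ahead of the training loop
            if tx.send(batch).is_err() {
                break; // the consumer dropped the receiver; stop loading
            }
        }
    });
    rx
}

// Given some iterator over (image, label) literal pairs, a training loop
// could then drain the receiver with
// `for (images, labels) in prefetch(batch_iter, 4) { ... }`,
// overlapping host-side loading with device execution.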
+fn load_mnist_batch(images: File, labels: File, batch_idx: u64) -> io::Result<(xla::Literal, xla::Literal)> { + let mut image_bytes = [0u8; 100*28*28]; + images.read_exact_at(&mut image_bytes, 100*28*28*batch_idx)?; + let mut images_xla = xla::Literal::vec1(&image_bytes); + images_xla = match images_xla.reshape(&[100, 28*28]) { + Ok(x) => x, + Err(_) => panic!("Failed to reshape MNIST image batch!") + }; + + let mut label_bytes = [0u8; 100*4]; + labels.read_exact_at(&mut label_bytes, 100*4*batch_idx)?; + let labels_u32 = label_bytes.chunks(4).map(|c| u32::from_le_bytes(c.try_into().unwrap())).collect::>(); + let labels_xla = xla::Literal::vec1(labels_u32.as_slice()); + + Ok((images_xla, labels_xla)) +} + +fn main() { + println!("Not yet implemented!"); +} diff --git a/src/core/graph/autodiff.rs b/src/core/graph/autodiff.rs index 87854cd..31567c9 100644 --- a/src/core/graph/autodiff.rs +++ b/src/core/graph/autodiff.rs @@ -98,7 +98,7 @@ impl Context { } Operation::OneHot(node) - | Operation::ReduceArgmax { node: node, dim: _ } => { + | Operation::ReduceArgmax { node, dim: _ } => { if self.gradient_is_dependent(node, dependent_node) { return Err(ContextError::NonDifferentiableOpError( self.nodes[dependent_node].callsite.clone(), diff --git a/src/core/graph/context.rs b/src/core/graph/context.rs index c79df8a..9860a9e 100644 --- a/src/core/graph/context.rs +++ b/src/core/graph/context.rs @@ -7,7 +7,7 @@ use slotmap::SlotMap; /// XLA computation graph context. // TODO: rename this to something meaningful pub struct Context { - pub(crate) nodes: SlotMap, + pub nodes: SlotMap, pub(crate) constants: Vec, pub(crate) parameters: Vec, pub(crate) dependent_nodes: HashMap>, @@ -33,7 +33,7 @@ pub enum ContextError { #[error("Tried to call typecast_const on non-constant node at {0}")] NonConstantTypecast(Callsite), - #[error("XLA error: {0}")] + #[error("XLA internal error: {0}. Unless this is a device error, Unda should not produce internal XLA errors. Please create a github issue.")] Xla(#[from] xla::Error), #[error("Unda internal graph processing error {0}")] @@ -54,8 +54,11 @@ pub enum ContextError { #[error("Type is not differentiable, differentiable types are F16, Bf16, F32, F64, C64, C128")] NonDifferentiableTypeError(Callsite), - #[error("Expected integral type")] - IntegralTypeError(Callsite), + #[error("Expected integral type, got {0}. Integral types are S8, S16, S32, S64, U8, U16, U32, U64")] + IntegralTypeError(xla::ElementType, Callsite), + + #[error("Expected real type, got {0}. 
Real types are F16, Bf16, F32, F64")] + RealTypeError(xla::ElementType, Callsite), #[error("Expected tensor of rank {0}, got {1}")] RankError(usize, usize, Callsite), diff --git a/src/core/graph/math.rs b/src/core/graph/math.rs index 530c61b..55e6a29 100644 --- a/src/core/graph/math.rs +++ b/src/core/graph/math.rs @@ -778,6 +778,8 @@ impl Context { self.maybe_keepdims(node_id, dim, keepdims) } + // assumes dense_predictions is rank 2 with dimension 0 being batch and dimension 1 being predictions + // assumes sparse_label_vector is rank 1 i64 of class labels pub fn accuracy( &mut self, dense_predictions: NodeIdentifier, @@ -813,7 +815,12 @@ impl Context { | xla::ElementType::U32 | xla::ElementType::S32 | xla::ElementType::U64 => self.type_cast(sparse_label_vector, xla::ElementType::S64), - _ => return Err(ContextError::IntegralTypeError(callsite!(1))), + _ => { + return Err(ContextError::IntegralTypeError( + self.nodes[sparse_label_vector].dtype, + callsite!(1), + )) + } }; let node_id = self.nodes.insert(Node { @@ -828,4 +835,36 @@ impl Context { .push(node_id); Ok(node_id) } + + pub fn mean_cross_entropy( + &mut self, + prediction_probabilities: NodeIdentifier, + one_hot_labels: NodeIdentifier, + ) -> Result { + let dtype = self.nodes[prediction_probabilities].dtype; + if dtype != self.nodes[one_hot_labels].dtype { + return Err(ContextError::IncompatibleOperandTypes( + dtype, + self.nodes[one_hot_labels].dtype, + callsite!(1), + )); + } + match dtype { + xla::ElementType::F16 + | xla::ElementType::Bf16 + | xla::ElementType::F32 + | xla::ElementType::F64 => {} + _ => return Err(ContextError::RealTypeError(dtype, callsite!(1))), + } + + let eps = self.scalar(1e-8, xla::ElementType::F32)?; + let eps = self.type_cast(eps, dtype); + // prevent logarithm of zero + let offset = self.add(prediction_probabilities, eps)?; + let log = self.log(offset)?; + let neglog = self.neg(log); + let mul = self.mul(one_hot_labels, neglog)?; + let sum = self.reduce_sum(mul, 1, false)?; + self.reduce_mean(sum, 0, false) + } } diff --git a/src/core/graph/node.rs b/src/core/graph/node.rs index d06a780..24a10c7 100644 --- a/src/core/graph/node.rs +++ b/src/core/graph/node.rs @@ -4,7 +4,7 @@ use super::*; use rand_distr::num_traits::Zero; use slotmap::new_key_type; use xla::Literal; -use std::{fmt::{Display, Formatter, Result}}; +use std::fmt::{Display, Formatter, Result}; use half::bf16; use half::f16; @@ -16,11 +16,11 @@ pub struct Node { // TODO: gate this so its not present at all in release builds pub(crate) callsite: Callsite, /// shape of the output of this node - pub(crate) shape: Shape, + pub shape: Shape, /// the operation this node performs pub(crate) operation: Operation, //// output type of the operation - pub(crate) dtype: xla::ElementType, + pub dtype: xla::ElementType, } @@ -50,7 +50,7 @@ impl Node { return Ok(false); } } - + }, xla::ElementType::F16 => { let data_ref = a.value.to_vec::()?; @@ -179,7 +179,7 @@ impl Node { return Ok(false); } } - + }, xla::ElementType::F16 => { let data_ref = a.value.to_vec::()?; From fb418c5a4f85c2adbda55d8c9fc4fda523dba837 Mon Sep 17 00:00:00 2001 From: Eben Kadile Date: Fri, 22 Mar 2024 16:55:08 +0100 Subject: [PATCH 05/21] progress on mnist example --- Cargo.lock | 2 +- examples/mnist_xla.rs | 161 +++++++++++++++++++++++++++++++++--------- 2 files changed, 128 insertions(+), 35 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 640c20b..8aea6e9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1157,7 +1157,7 @@ checksum = 
"32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" [[package]] name = "xla" version = "0.1.6" -source = "git+https://github.com/Ebanflo42/xla-rs?branch=dev#b436bb3fed5bbf4e60d406e51070673f0c0de1f0" +source = "git+https://github.com/Ebanflo42/xla-rs?branch=dev#61d7d8af8891eabed91b91990a9d09bddd8e86fd" dependencies = [ "bindgen", "cc", diff --git a/examples/mnist_xla.rs b/examples/mnist_xla.rs index e3373b9..035c4af 100644 --- a/examples/mnist_xla.rs +++ b/examples/mnist_xla.rs @@ -1,23 +1,29 @@ -use unda::core::graph::*; -use xla::{ElementType::*, PjRtLoadedExecutable}; -use std::io; use std::fs::File; +use std::io; use std::os::unix::fs::*; -use byteorder::{LittleEndian, ReadBytesExt}; +use unda::core::graph::*; +use xla::{ElementType::*, PjRtClient, PjRtLoadedExecutable}; + +const USE_CPU: bool = false; +const MNIST_DIRECTORY: &str = "/home/ekadile/mnist"; +const EPOCHS: usize = 20; +const INIT_LEARNING_RATE: f32 = 1e-3; +const LEARNING_RATE_DECAY: f32 = 0.95; // ABSTRACT API REQUIREMENT 1: Automatic Layer Construction // We should have functions like this which, for a given layer type, // automatically resolve shapes and dtypes and construct nodes for // the parameters and outputs of a layer. In the final version, // a function like this should also take an "initialization" parameter -// and run random initialization for the weights and bias. +// and run random initialization for the weights and bias using XLA's +// random number generation functions. fn dense( model: &mut Context, input_node: NodeIdentifier, out_size: u32, name: &str, ) -> Result<(NodeIdentifier, (NodeIdentifier, NodeIdentifier))> { - let shape = model.nodes[input_node].shape; + let shape = model.nodes[input_node].shape.clone(); let last_dim = shape.sizes[shape.ndims() - 1]; let dtype = model.nodes[input_node].dtype; @@ -60,7 +66,7 @@ fn build_model_and_optimizer(client: &xla::PjRtClient) -> Result Result xla::Literal { + let size = shape.size(); + + // this is a goofy initialization which I just thought of + // nobody uses this, but it doesn't matter because MNIST is simple + let amplitude = 1f32 / (size as f32).sqrt(); + let mut vec = Vec::new(); + for i in 0..size { + if i % 2 == 0 { + vec.push(amplitude); + } else { + vec.push(-amplitude); + }; + } + let vec1 = xla::Literal::vec1(vec.as_slice()); + + let xla_shape = shape.sizes.iter().map(|d| *d as i64).collect::>(); + match vec1.reshape(xla_shape.as_slice()) { + Ok(x) => x, + _ => panic!("Failed to reshape initial paramter value!"), + } +} + +// ABSTRACT API REQUIREMENT 5: Parameter structure abstraction +// This relates closely with ABSTRACT API REQUIREMENT 1 +// This example works because I know the exact order in which parameters +// are declared in the model context. This becomes insanely hard +// to keep track of as the architecture grows, and the user shouldn't +// have to worry about it. +fn init_params() -> ( + xla::Literal, + xla::Literal, + xla::Literal, + xla::Literal, + xla::Literal, + xla::Literal, + xla::Literal, + xla::Literal, +) { + ( + init_param(Shape::from([28 * 28, 784])), + init_param(Shape::from([1, 784])), + init_param(Shape::from([784, 256])), + init_param(Shape::from([1, 256])), + init_param(Shape::from([256, 64])), + init_param(Shape::from([1, 64])), + init_param(Shape::from([64, 10])), + init_param(Shape::from([1, 10])), + ) +} + +// ABSTRACT API REQUIREMENT 6: Data prefetching // Data input to the training loop should be in the form of an iterator. // These iterators could be finite or infinite. 
// In the finite case, we should support random shuffling. -// In either case, batches of data should be pre-fetched in parallel +// In either case, batches of data should be pre-fetched and queued in parallel // (and potentially preprocessed by the CPU) as the training loop is executing. -fn load_mnist_batch(images: File, labels: File, batch_idx: u64) -> io::Result<(xla::Literal, xla::Literal)> { - let mut image_bytes = [0u8; 100*28*28]; - images.read_exact_at(&mut image_bytes, 100*28*28*batch_idx)?; +fn load_mnist_batch( + images: &File, + labels: &File, + batch_idx: u64, +) -> io::Result<(xla::Literal, xla::Literal)> { + let mut image_bytes = [0u8; 100 * 28 * 28]; + images.read_exact_at(&mut image_bytes, 100 * 28 * 28 * batch_idx)?; let mut images_xla = xla::Literal::vec1(&image_bytes); - images_xla = match images_xla.reshape(&[100, 28*28]) { + images_xla = match images_xla.reshape(&[100, 28 * 28]) { Ok(x) => x, - Err(_) => panic!("Failed to reshape MNIST image batch!") + Err(_) => panic!("Failed to reshape MNIST image batch!"), }; - let mut label_bytes = [0u8; 100*4]; - labels.read_exact_at(&mut label_bytes, 100*4*batch_idx)?; - let labels_u32 = label_bytes.chunks(4).map(|c| u32::from_le_bytes(c.try_into().unwrap())).collect::>(); - let labels_xla = xla::Literal::vec1(labels_u32.as_slice()); + let mut label_bytes = [0u8; 100 * 4]; + labels.read_exact_at(&mut label_bytes, 100 * 4 * batch_idx)?; + let labels_xla = xla::Literal::vec1(&label_bytes); Ok((images_xla, labels_xla)) } fn main() { + let client = if USE_CPU { + PjRtClient::cpu().expect("Failed to construct CPU client") + } else { + PjRtClient::gpu(0.9, false).expect("Failed to construct GPU client") + }; + + let mut train_img_path = MNIST_DIRECTORY.to_owned(); + train_img_path.push_str("/train-images-idx3-ubyte"); + let train_images = File::open(train_img_path).expect("Failed to open training image file"); + + let mut train_lbl_path = MNIST_DIRECTORY.to_owned(); + train_lbl_path.push_str("/train-labels-idx1-ubyte"); + let train_labels = File::open(train_lbl_path).expect("Failed to open training label file"); + + let mut test_img_path = MNIST_DIRECTORY.to_owned(); + test_img_path.push_str("/t10k-images-idx3-ubyte"); + let test_images = File::open(test_img_path).expect("Failed to open training image file"); + + let mut test_lbl_path = MNIST_DIRECTORY.to_owned(); + test_lbl_path.push_str("/t10k-labels-idx1-ubyte"); + let test_labels = File::open(test_lbl_path).expect("Failed to open training label file"); + + let executable = + build_model_and_optimizer(&client).expect("Failed to build model and optimizer"); + + let (mut w1, mut b1, mut w2, mut b2, mut w3, mut b3, mut w_out, mut b_out) = init_params(); + + for epoch in 0..EPOCHS { + for batch_idx in 0..600 { + let (train_imgs, train_lbls) = + load_mnist_batch(&train_images, &train_labels, batch_idx) + .expect("Failed to load MNIST batch"); + let lr_literal = + xla::Literal::scalar(INIT_LEARNING_RATE * (LEARNING_RATE_DECAY.powf(epoch as f32))); + } + } + println!("Not yet implemented!"); } From 8914218acaa5ca05f2baa7b302a7d6a2a5a83dc9 Mon Sep 17 00:00:00 2001 From: Eben Kadile Date: Sat, 23 Mar 2024 00:02:58 +0100 Subject: [PATCH 06/21] example building, untested --- Cargo.lock | 4 +- examples/mnist_xla.rs | 123 ++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 122 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8aea6e9..81fe167 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -45,9 +45,9 @@ checksum = 
"d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "backtrace" -version = "0.3.70" +version = "0.3.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95d8e92cac0961e91dbd517496b00f7e9b92363dbe6d42c3198268323798860c" +checksum = "26b05800d2e817c8b3b4b54abd461726265fa9789ae34330622f2db9ee696f9d" dependencies = [ "addr2line", "cc", diff --git a/examples/mnist_xla.rs b/examples/mnist_xla.rs index 035c4af..fca90c0 100644 --- a/examples/mnist_xla.rs +++ b/examples/mnist_xla.rs @@ -34,7 +34,7 @@ fn dense( let mut bias_shape = Shape::new(); bias_shape.sizes.push(out_size); - for i in 0..(shape.ndims() - 1) { + for _ in 0..(shape.ndims() - 1) { bias_shape.sizes.push(1u32); } let mut bias_name = name.to_owned(); @@ -246,14 +246,131 @@ fn main() { let (mut w1, mut b1, mut w2, mut b2, mut w3, mut b3, mut w_out, mut b_out) = init_params(); for epoch in 0..EPOCHS { + let mut train_accuracy = 0f32; + let mut train_loss = 0f32; + for batch_idx in 0..600 { let (train_imgs, train_lbls) = load_mnist_batch(&train_images, &train_labels, batch_idx) .expect("Failed to load MNIST batch"); - let lr_literal = + + let lr = xla::Literal::scalar(INIT_LEARNING_RATE * (LEARNING_RATE_DECAY.powf(epoch as f32))); + + // This is where ABSTRACT API REQUIREMENT 5 becomes pertinent + // The user should not have to explicitly reference a dozen parameters like this + let xla_buffer = executable + .execute(&[ + &train_imgs, + &train_lbls, + &w1, + &b1, + &w2, + &b2, + &w3, + &b3, + &w_out, + &b_out, + &lr, + ]) + .expect("Failed to run PjRt executable"); + + // This is where ABSTRACT API REQUIREMENT 4 becomes pertinent + // The user should not have to move all this junk to the host just to get accuracy and loss + let xla_literal = xla_buffer[0][0] + .to_literal_sync() + .expect("Failed to copy buffer to host"); + let untupled_literals = xla_literal + .to_tuple() + .expect("Failed to untuple XLA literals"); + + let loss = untupled_literals[0] + .to_vec::() + .expect("Failed vector conversion of loss")[0]; + train_loss += loss; + let accuracy = untupled_literals[1] + .to_vec::() + .expect("Failed vector conversion of accuracy")[0]; + train_accuracy += accuracy; + + // This is really very silly. Because model/optimizer are not separate + // we move the weights to the CPU just to move them back + // Even without that, is there a way to get rid of the clone?? + (w1, b1, w2, b2, w3, b3, w_out, b_out) = ( + untupled_literals[2].clone(), + untupled_literals[3].clone(), + untupled_literals[4].clone(), + untupled_literals[5].clone(), + untupled_literals[6].clone(), + untupled_literals[7].clone(), + untupled_literals[8].clone(), + untupled_literals[9].clone() + ); } + println!( + "Epoch {}: Training loss = {}; Training accuracy = {}", + epoch, + train_loss / 600f32, + train_accuracy / 600f32 + ); } - println!("Not yet implemented!"); + // ABSTRACT API REQUIREMENT 7: Serialization + // The model is not worth very much if it disappears after our training loop. + // My main suggestion is to serialize the compute graph using XLA HLO + // and serialize the paramters using the npz format. + + let mut test_accuracy = 0f32; + let mut test_loss = 0f32; + + for batch_idx in 0..100 { + let (test_imgs, test_lbls) = + load_mnist_batch(&test_images, &test_labels, batch_idx) + .expect("Failed to load MNIST batch"); + + // GOOFY!! 
+ // Another consequence of ABSTRACT API REQUIREMENT 4 Not being implemented + // To prevent the model from training on the testing data, I have to + // set the learning rate to zero + let lr = xla::Literal::scalar(0f32); + + let xla_buffer = executable + .execute(&[ + &test_imgs, + &test_lbls, + &w1, + &b1, + &w2, + &b2, + &w3, + &b3, + &w_out, + &b_out, + &lr, + ]) + .expect("Failed to run PjRt executable"); + + // This is where ABSTRACT API REQUIREMENT 4 becomes pertinent + // The user should not have to move all this junk to the host just to get accuracy and loss + let xla_literal = xla_buffer[0][0] + .to_literal_sync() + .expect("Failed to copy buffer to host"); + let untupled_literals = xla_literal + .to_tuple() + .expect("Failed to untuple XLA literals"); + + let loss = untupled_literals[0] + .to_vec::() + .expect("Failed vector conversion of loss")[0]; + test_loss += loss; + let accuracy = untupled_literals[1] + .to_vec::() + .expect("Failed vector conversion of accuracy")[0]; + test_accuracy += accuracy; + } + println!( + "Testing loss = {}; Testing accuracy = {}", + test_loss / 100f32, + test_accuracy / 100f32 + ); } From cb2a9fe08e82f55ff84ed68d17cb2f3f421daaf6 Mon Sep 17 00:00:00 2001 From: Eben Kadile Date: Sun, 24 Mar 2024 23:45:44 +0100 Subject: [PATCH 07/21] debugging mnist --- examples/mnist_xla.rs | 18 +++++++++++------- src/core/graph/math.rs | 29 +++++++++++++++++++---------- 2 files changed, 30 insertions(+), 17 deletions(-) diff --git a/examples/mnist_xla.rs b/examples/mnist_xla.rs index fca90c0..d4d470f 100644 --- a/examples/mnist_xla.rs +++ b/examples/mnist_xla.rs @@ -5,10 +5,10 @@ use unda::core::graph::*; use xla::{ElementType::*, PjRtClient, PjRtLoadedExecutable}; const USE_CPU: bool = false; -const MNIST_DIRECTORY: &str = "/home/ekadile/mnist"; +const MNIST_DIRECTORY: &str = "/home/medusa/mnist"; const EPOCHS: usize = 20; const INIT_LEARNING_RATE: f32 = 1e-3; -const LEARNING_RATE_DECAY: f32 = 0.95; +const LEARNING_RATE_DECAY: f32 = 0.9; // ABSTRACT API REQUIREMENT 1: Automatic Layer Construction // We should have functions like this which, for a given layer type, @@ -24,24 +24,27 @@ fn dense( name: &str, ) -> Result<(NodeIdentifier, (NodeIdentifier, NodeIdentifier))> { let shape = model.nodes[input_node].shape.clone(); + //println!("{:?}", shape.sizes); let last_dim = shape.sizes[shape.ndims() - 1]; let dtype = model.nodes[input_node].dtype; - let weights_shape = Shape::from([out_size, last_dim]); + //println!("weight shape {} {}", last_dim, out_size); + let weights_shape = Shape::from([last_dim, out_size]); let mut weights_name = name.to_owned(); weights_name.push_str("_weights"); let weights = model.parameter(weights_name, weights_shape, dtype)?; let mut bias_shape = Shape::new(); - bias_shape.sizes.push(out_size); for _ in 0..(shape.ndims() - 1) { bias_shape.sizes.push(1u32); } + bias_shape.sizes.push(out_size); + //println!("bias shape {:?}", bias_shape.sizes); let mut bias_name = name.to_owned(); bias_name.push_str("_bias"); let bias = model.parameter(bias_name, bias_shape, dtype)?; - let matmul_node = model.matmul(weights, input_node)?; + let matmul_node = model.matmul(input_node, weights)?; let dense_node = model.add(matmul_node, bias)?; Ok((dense_node, (weights, bias))) @@ -88,6 +91,7 @@ fn build_model_and_optimizer(client: &xla::PjRtClient) -> Result io::Result<(xla::Literal, xla::Literal)> { let mut image_bytes = [0u8; 100 * 28 * 28]; - images.read_exact_at(&mut image_bytes, 100 * 28 * 28 * batch_idx)?; + images.read_exact_at(&mut image_bytes, 8 + 100 
* 28 * 28 * batch_idx)?; let mut images_xla = xla::Literal::vec1(&image_bytes); images_xla = match images_xla.reshape(&[100, 28 * 28]) { Ok(x) => x, @@ -211,7 +215,7 @@ fn load_mnist_batch( }; let mut label_bytes = [0u8; 100 * 4]; - labels.read_exact_at(&mut label_bytes, 100 * 4 * batch_idx)?; + labels.read_exact_at(&mut label_bytes, 8 + 100 * 4 * batch_idx)?; let labels_xla = xla::Literal::vec1(&label_bytes); Ok((images_xla, labels_xla)) diff --git a/src/core/graph/math.rs b/src/core/graph/math.rs index 55e6a29..91572d9 100644 --- a/src/core/graph/math.rs +++ b/src/core/graph/math.rs @@ -756,15 +756,8 @@ impl Context { dim: i64, keepdims: bool, ) -> Result { - let a = a.into(); - let mut s = Shape::new(); - for d in (0..self.nodes[a].shape.ndims()).rev() { - if d as i64 == dim && keepdims { - s.sizes.push(1) - } else { - s.sizes.push(self.nodes[a].shape.sizes[d]) - } - } + let mut s = self.nodes[a].shape.clone(); + s.sizes.remove(dim as usize); let node_id = self.nodes.insert(Node { callsite: callsite!(1), shape: s, @@ -785,8 +778,24 @@ impl Context { dense_predictions: NodeIdentifier, sparse_label_vector: NodeIdentifier, ) -> Result { + let converted_labels = match self.nodes[sparse_label_vector].dtype { + xla::ElementType::S64 => sparse_label_vector, + xla::ElementType::U8 + | xla::ElementType::S8 + | xla::ElementType::U16 + | xla::ElementType::S16 + | xla::ElementType::U32 + | xla::ElementType::S32 + | xla::ElementType::U64 => self.type_cast(sparse_label_vector, xla::ElementType::S64), + _ => { + return Err(ContextError::IntegralTypeError( + self.nodes[sparse_label_vector].dtype, + callsite!(1), + )) + } + }; let sparse_predictions = self.reduce_argmax(dense_predictions, 1, false)?; - let compare = self.eq(sparse_predictions, sparse_label_vector)?; + let compare = self.eq(sparse_predictions, converted_labels)?; let converted = self.type_cast(compare, xla::ElementType::F32); self.reduce_mean(converted, 0, false) } From fdac5dc2cbcdc12a093688c66cea0ae59bc6f712 Mon Sep 17 00:00:00 2001 From: Eben Kadile Date: Mon, 25 Mar 2024 10:59:40 +0100 Subject: [PATCH 08/21] debugging autodiff --- examples/mnist_xla.rs | 14 +- src/core/graph/autodiff.rs | 438 ++++++++++++++++++------------------- src/core/graph/math.rs | 14 +- 3 files changed, 221 insertions(+), 245 deletions(-) diff --git a/examples/mnist_xla.rs b/examples/mnist_xla.rs index d4d470f..962a459 100644 --- a/examples/mnist_xla.rs +++ b/examples/mnist_xla.rs @@ -5,7 +5,7 @@ use unda::core::graph::*; use xla::{ElementType::*, PjRtClient, PjRtLoadedExecutable}; const USE_CPU: bool = false; -const MNIST_DIRECTORY: &str = "/home/medusa/mnist"; +const MNIST_DIRECTORY: &str = "/home/ekadile/mnist"; const EPOCHS: usize = 20; const INIT_LEARNING_RATE: f32 = 1e-3; const LEARNING_RATE_DECAY: f32 = 0.9; @@ -24,11 +24,9 @@ fn dense( name: &str, ) -> Result<(NodeIdentifier, (NodeIdentifier, NodeIdentifier))> { let shape = model.nodes[input_node].shape.clone(); - //println!("{:?}", shape.sizes); let last_dim = shape.sizes[shape.ndims() - 1]; let dtype = model.nodes[input_node].dtype; - //println!("weight shape {} {}", last_dim, out_size); let weights_shape = Shape::from([last_dim, out_size]); let mut weights_name = name.to_owned(); weights_name.push_str("_weights"); @@ -39,7 +37,6 @@ fn dense( bias_shape.sizes.push(1u32); } bias_shape.sizes.push(out_size); - //println!("bias shape {:?}", bias_shape.sizes); let mut bias_name = name.to_owned(); bias_name.push_str("_bias"); let bias = model.parameter(bias_name, bias_shape, dtype)?; @@ -81,6 +78,7 @@ fn 
build_model_and_optimizer(client: &xla::PjRtClient) -> Result Result Result Result xla::Literal { let size = shape.size(); - // this is a goofy initialization which I just thought of + // deterministic version of xavier initialization // nobody uses this, but it doesn't matter because MNIST is simple - let amplitude = 1f32 / (size as f32).sqrt(); + let amplitude = 1f32 / ((shape.sizes[0] + shape.sizes[1]) as f32).sqrt(); let mut vec = Vec::new(); for i in 0..size { if i % 2 == 0 { diff --git a/src/core/graph/autodiff.rs b/src/core/graph/autodiff.rs index 31567c9..23fb874 100644 --- a/src/core/graph/autodiff.rs +++ b/src/core/graph/autodiff.rs @@ -10,10 +10,7 @@ impl Context { operation: Operation::StopGradient(node), dtype: self.nodes[node].dtype, }); - self.dependent_nodes - .entry(node) - .or_default() - .push(new_node); + self.dependent_nodes.entry(node).or_default().push(new_node); new_node } @@ -75,295 +72,284 @@ impl Context { let dependent_nodes = self.dependent_nodes[&with_respect_to].clone(); for dependent_node in dependent_nodes { - //Again again, clone() here is not wonderful, there's gotta be a better way to - //store the i64 vec for Transpose - match self.nodes[dependent_node].operation.clone() { - Operation::Constant(_) => panic!("Constant found as dependent node!"), - Operation::Parameter(_) => panic!("Parameter found as dependent node!"), - Operation::StopGradient(_) => continue, - - Operation::Equal(_, _) - | Operation::NotEqual(_, _) - | Operation::LessThan(_, _) - | Operation::LessThanEq(_, _) - | Operation::GreaterThan(_, _) - | Operation::GreaterThanEq(_, _) => { - if self.gradient_is_dependent(output, dependent_node) { + if self.gradient_is_dependent(output, dependent_node) { + //Again again, clone() here is not wonderful, there's gotta be a better way to + //store the i64 vec for Transpose + match self.nodes[dependent_node].operation.clone() { + Operation::Constant(_) => panic!("Constant found as dependent node!"), + Operation::Parameter(_) => panic!("Parameter found as dependent node!"), + Operation::StopGradient(_) => continue, + + Operation::Equal(_, _) + | Operation::NotEqual(_, _) + | Operation::LessThan(_, _) + | Operation::LessThanEq(_, _) + | Operation::GreaterThan(_, _) + | Operation::GreaterThanEq(_, _) => { return Err(ContextError::NonDifferentiableOpError( self.nodes[dependent_node].callsite.clone(), )); - } else { - continue; } - } - Operation::OneHot(node) - | Operation::ReduceArgmax { node, dim: _ } => { - if self.gradient_is_dependent(node, dependent_node) { + Operation::OneHot(_) + | Operation::ReduceArgmax { node: _, dim: _ } + | Operation::TypeCast(_, _) => { return Err(ContextError::NonDifferentiableOpError( self.nodes[dependent_node].callsite.clone(), )); - } else { - continue; } - } - Operation::TypeCast(_, _) => { - if self.gradient_is_dependent(output, dependent_node) { - return Err(ContextError::NonDifferentiableOpError( - self.nodes[dependent_node].callsite.clone(), - )); - } else { - continue; - } - } - - Operation::Reshape(node) => { - let next_pullback = self.diff(output, dependent_node)?; - let node_sh = self.nodes[node].shape.clone(); - let pullback = self.reshape(next_pullback, node_sh)?; - dependent_pullbacks.push(pullback); - } - - Operation::Transpose(a, p) => { - if a == with_respect_to { + Operation::Reshape(node) => { let next_pullback = self.diff(output, dependent_node)?; - let inv_perm = Context::inv_perm(&p); - - let pullback = self.transpose(next_pullback, &inv_perm)?; + let node_sh = self.nodes[node].shape.clone(); + 
println!("Reshape {} \n{}", self.to_string(self.dependent_nodes[&dependent_node][0]), self.nodes[dependent_node].shape); + let pullback = self.reshape(next_pullback, node_sh)?; dependent_pullbacks.push(pullback); } - } - Operation::ZerosLike(_) => continue, + Operation::Transpose(a, p) => { + if a == with_respect_to { + let next_pullback = self.diff(output, dependent_node)?; + let inv_perm = Context::inv_perm(&p); - Operation::Add(a, b) => { - if a == with_respect_to { - dependent_pullbacks.push(self.diff(output, dependent_node)?); - } - if b == with_respect_to { - dependent_pullbacks.push(self.diff(output, dependent_node)?); + let pullback = self.transpose(next_pullback, &inv_perm)?; + dependent_pullbacks.push(pullback); + } } - } - Operation::Sub(a, b) => { - if a == with_respect_to { - dependent_pullbacks.push(self.diff(output, dependent_node)?); - } - if b == with_respect_to { - let next_pullback = self.diff(output, dependent_node)?; - dependent_pullbacks.push(self.neg(next_pullback)); - } - } + Operation::ZerosLike(_) => continue, - Operation::Mul(a, b) => { - let next_pullback = self.diff(output, dependent_node)?; - if a == b && a == with_respect_to { - let two = self.scalar(2, wrt_dtype)?; - let mul = self.mul(two, a)?; - dependent_pullbacks.push(self.mul(mul, next_pullback)?); - } else if a == with_respect_to { - let mul = self.mul(next_pullback, a)?; - dependent_pullbacks.push(mul); - } else if b == with_respect_to { - let mul = self.mul(a, next_pullback)?; - dependent_pullbacks.push(mul); + Operation::Add(a, b) => { + println!("Add"); + if a == with_respect_to { + dependent_pullbacks.push(self.diff(output, dependent_node)?); + } + if b == with_respect_to { + dependent_pullbacks.push(self.diff(output, dependent_node)?); + } } - } - - Operation::MatMul(a, b) => { - let next_pullback = self.diff(output, dependent_node)?; - if a == with_respect_to { - let mut transpose_dims: Vec = vec![]; - - let size = self.nodes[b].shape.sizes.len(); - - for &dim in self.nodes[b].shape.sizes.iter() { - transpose_dims.push(dim as i64); + Operation::Sub(a, b) => { + if a == with_respect_to { + dependent_pullbacks.push(self.diff(output, dependent_node)?); } - - transpose_dims.swap(size - 2, size - 1); - - let transpose = self.transpose(b, &transpose_dims)?; - let this_pullback = self.mul(transpose, next_pullback)?; - dependent_pullbacks.push(this_pullback); - } else if b == with_respect_to { - let mut transpose_dims: Vec = vec![]; - - let size = self.nodes[a].shape.sizes.len(); - - for &dim in self.nodes[a].shape.sizes.iter() { - transpose_dims.push(dim as i64); + if b == with_respect_to { + let next_pullback = self.diff(output, dependent_node)?; + dependent_pullbacks.push(self.neg(next_pullback)); } + } - transpose_dims.swap(size - 2, size - 1); - - let transpose = self.transpose(a, &transpose_dims)?; - let this_pullback = self.mul(transpose, next_pullback)?; - dependent_pullbacks.push(this_pullback); + Operation::Mul(a, b) => { + let next_pullback = self.diff(output, dependent_node)?; + println!("Mul {} {}", self.nodes[a].shape, self.nodes[b].shape); + if a == b && a == with_respect_to { + let two = self.scalar(2, wrt_dtype)?; + let mul = self.mul(two, a)?; + dependent_pullbacks.push(self.mul(mul, next_pullback)?); + } else if a == with_respect_to { + let mul = self.mul(next_pullback, a)?; + dependent_pullbacks.push(mul); + } else if b == with_respect_to { + let mul = self.mul(a, next_pullback)?; + dependent_pullbacks.push(mul); + } } - } - Operation::Div(a, b) => { - let next_pullback = 
self.diff(output, dependent_node)?; - if a == with_respect_to { - let div = self.div(next_pullback, b)?; - dependent_pullbacks.push(div); + Operation::MatMul(a, b) => { + let next_pullback = self.diff(output, dependent_node)?; + println!("got to matmul"); + + if a == with_respect_to { + let mut transpose_dims = self.nodes[a] + .shape + .sizes + .iter() + .map(|s| *s as i64) + .collect::>(); + let ndims = transpose_dims.len(); + transpose_dims.swap(ndims - 2, ndims - 1); + + let transpose = self.transpose(a, transpose_dims.as_slice())?; + let this_pullback = self.matmul(transpose, next_pullback)?; + dependent_pullbacks.push(this_pullback); + } else if b == with_respect_to { + let mut transpose_dims = self.nodes[b] + .shape + .sizes + .iter() + .map(|s| *s as i64) + .collect::>(); + let ndims = transpose_dims.len(); + transpose_dims.swap(ndims - 2, ndims - 1); + + let transpose = self.transpose(b, &transpose_dims)?; + let this_pullback = self.matmul(next_pullback, transpose)?; + dependent_pullbacks.push(this_pullback); + } } - if b == with_respect_to { - let mul = self.mul(b, b)?; - let div = self.div(a, mul)?; - let neg = self.neg(div); - let this_pullback = self.mul(neg, next_pullback)?; - dependent_pullbacks.push(this_pullback); + + Operation::Div(a, b) => { + let next_pullback = self.diff(output, dependent_node)?; + if a == with_respect_to { + let div = self.div(next_pullback, b)?; + dependent_pullbacks.push(div); + } + if b == with_respect_to { + let mul = self.mul(b, b)?; + let div = self.div(a, mul)?; + let neg = self.neg(div); + let this_pullback = self.mul(neg, next_pullback)?; + dependent_pullbacks.push(this_pullback); + } } - } - Operation::Pow(a, b) => { - let next_pullback = self.diff(output, dependent_node)?; - if a == with_respect_to { - let one = self.scalar(1, wrt_dtype)?; - let b_min_one = self.sub(b, one)?; + Operation::Pow(a, b) => { + let next_pullback = self.diff(output, dependent_node)?; + if a == with_respect_to { + let one = self.scalar(1, wrt_dtype)?; + let b_min_one = self.sub(b, one)?; - let new_pow = self.pow(a, b_min_one)?; - let power_rule = self.mul(b, new_pow)?; + let new_pow = self.pow(a, b_min_one)?; + let power_rule = self.mul(b, new_pow)?; - let this_pullback = self.mul(power_rule, next_pullback)?; - dependent_pullbacks.push(this_pullback); - } - if b == with_respect_to { - let log_a = self.log(a)?; - let log_times_orig = self.mul(log_a, dependent_node)?; - let this_pullback = self.mul(log_times_orig, next_pullback)?; + let this_pullback = self.mul(power_rule, next_pullback)?; + dependent_pullbacks.push(this_pullback); + } + if b == with_respect_to { + let log_a = self.log(a)?; + let log_times_orig = self.mul(log_a, dependent_node)?; + let this_pullback = self.mul(log_times_orig, next_pullback)?; - dependent_pullbacks.push(this_pullback); + dependent_pullbacks.push(this_pullback); + } } - } - Operation::Log(a) => { - if a == with_respect_to { + Operation::Log(a) => { let next_pullback = self.diff(output, dependent_node)?; let one = self.scalar(1, wrt_dtype)?; let quotient = self.div(one, a)?; + println!("Log"); let next_pullback = self.mul(quotient, next_pullback)?; dependent_pullbacks.push(next_pullback); } - } - Operation::Neg(_) => { - let next_pullback = self.diff(output, dependent_node)?; - dependent_pullbacks.push(self.neg(next_pullback)); - }, + Operation::Neg(_) => { + println!("Neg"); + let next_pullback = self.diff(output, dependent_node)?; + dependent_pullbacks.push(self.neg(next_pullback)); + } - Operation::Exp(a) => { - if a == 
with_respect_to { + Operation::Exp(_) => { let next_pullback = self.diff(output, dependent_node)?; + println!("Exp"); let this_pullback = self.mul(next_pullback, dependent_node)?; dependent_pullbacks.push(this_pullback); } - } - Operation::TileInDim { node, n_tiles, dim } => { - let next_pullback = self.diff(output, dependent_node)?; + Operation::TileInDim { node, n_tiles, dim } => { + let next_pullback = self.diff(output, dependent_node)?; - let mut new_sizes = SmallVec::new(); - for i in (0..self.nodes[node].shape.ndims()).rev() { - new_sizes.push(self.nodes[node].shape.sizes[i]); - if i as i64 == dim { - new_sizes.push(n_tiles as u32); + let mut new_sizes = SmallVec::new(); + for i in (0..self.nodes[node].shape.ndims()).rev() { + new_sizes.push(self.nodes[node].shape.sizes[i]); + if i as i64 == dim { + new_sizes.push(n_tiles as u32); + } } - } - let reshaped_pullback = - self.reshape(next_pullback, Shape { sizes: new_sizes })?; - dependent_pullbacks.push(self.reduce_sum(reshaped_pullback, dim, false)?); - } + let reshaped_pullback = + self.reshape(next_pullback, Shape { sizes: new_sizes })?; + dependent_pullbacks.push(self.reduce_sum( + reshaped_pullback, + dim, + false, + )?); + } - Operation::SliceInDim { - node, - start: _, - stop: _, - stride: _, - dim: _, - } => { - if self.gradient_is_dependent(node, dependent_node) { - panic!( + Operation::SliceInDim { + node, + start: _, + stop: _, + stride: _, + dim: _, + } => { + if self.gradient_is_dependent(node, dependent_node) { + panic!( "Gradient of SliceInDim requires XLA scatter op to be implemented." ); - } else { - continue; + } else { + continue; + } } - } - Operation::ReduceMax { - node, - dim: _, - } => { - if self.gradient_is_dependent(node, dependent_node) { - panic!( + Operation::ReduceMax { node, dim: _ } => { + if self.gradient_is_dependent(node, dependent_node) { + panic!( "Gradient of ReduceMax requires XLA scatter op to be implemented." 
); - } else { - continue; + } else { + continue; + } } - } - - Operation::ReduceSum { - node, - dim, - } => { - let next_pullback = self.diff(output, dependent_node)?; - let n_tiles = self.nodes[node].shape.sizes[dim as usize] as i64; - let mut new_shape = self.nodes[next_pullback].shape.clone(); - new_shape.sizes.insert(dim as usize, 1u32); - let reshaped_pullback = - self.reshape(next_pullback, new_shape)?; - let tiled_pullback = self.tile_in_dim(reshaped_pullback, n_tiles, dim)?; + Operation::ReduceSum { node, dim } => { + let next_pullback = self.diff(output, dependent_node)?; + let n_tiles = self.nodes[node].shape.sizes[dim as usize] as i64; - dependent_pullbacks.push(tiled_pullback); - } + let mut new_shape = self.nodes[next_pullback].shape.clone(); + new_shape.sizes.insert(dim as usize, 1u32); + println!("got to reducesum"); + let reshaped_pullback = + self.reshape(next_pullback, new_shape.clone())?; + let tiled_pullback = + self.tile_in_dim(reshaped_pullback, n_tiles, dim)?; - Operation::ReduceMean { node, dim } => { - let next_pullback = self.diff(output, dependent_node)?; - let n_tiles = self.nodes[node].shape.sizes[dim as usize] as i64; + dependent_pullbacks.push(tiled_pullback); + } - let mut new_sizes = SmallVec::new(); - for i in (0..self.nodes[next_pullback].shape.ndims()).rev() { - new_sizes.push(self.nodes[next_pullback].shape.sizes[i]); - if i as i64 == dim { + Operation::ReduceMean { node, dim } => { + let next_pullback = self.diff(output, dependent_node)?; + let n_tiles = self.nodes[node].shape.sizes[dim as usize] as i64; + + let mut new_sizes = SmallVec::new(); + for i in (0..self.nodes[next_pullback].shape.ndims()).rev() { + new_sizes.push(self.nodes[next_pullback].shape.sizes[i]); + if i as i64 == dim { + new_sizes.push(1u32); + } + } + if self.nodes[next_pullback].shape.ndims() == 0 { new_sizes.push(1u32); } + println!("got to reducemean"); + let reshaped_pullback = + self.reshape(next_pullback, Shape { sizes: new_sizes })?; + let tiled_pullback = + self.tile_in_dim(reshaped_pullback, n_tiles, dim)?; + + let scale = + self.scalar(1.0 / (n_tiles as f32), self.nodes[node].dtype)?; + let rescaled_pullback = self.mul(scale, tiled_pullback)?; + dependent_pullbacks.push(rescaled_pullback); } - if self.nodes[next_pullback].shape.ndims() == 0 { - new_sizes.push(1u32); - } - let reshaped_pullback = - self.reshape(next_pullback, Shape { sizes: new_sizes })?; - let tiled_pullback = self.tile_in_dim(reshaped_pullback, n_tiles, dim)?; - - let scale = self.scalar(1.0 / (n_tiles as f32), self.nodes[node].dtype)?; - let rescaled_pullback = self.mul(scale, tiled_pullback)?; - dependent_pullbacks.push(rescaled_pullback); - } - Operation::Select { - pred, - on_true, - on_false, - } => { - let next_pullback = self.diff(output, dependent_node)?; - let const_zero = self.scalar(0, wrt_dtype)?; - if on_true == with_respect_to { - let select = self.select(pred, next_pullback, const_zero)?; - dependent_pullbacks.push(select); - } - if on_false == with_respect_to { - let select = self.select(pred, const_zero, next_pullback)?; - dependent_pullbacks.push(select); + Operation::Select { + pred, + on_true, + on_false, + } => { + let next_pullback = self.diff(output, dependent_node)?; + let const_zero = self.scalar(0, wrt_dtype)?; + if on_true == with_respect_to { + let select = self.select(pred, next_pullback, const_zero)?; + dependent_pullbacks.push(select); + } + if on_false == with_respect_to { + let select = self.select(pred, const_zero, next_pullback)?; + dependent_pullbacks.push(select); + } 
} } } diff --git a/src/core/graph/math.rs b/src/core/graph/math.rs index 91572d9..18ab433 100644 --- a/src/core/graph/math.rs +++ b/src/core/graph/math.rs @@ -635,15 +635,8 @@ impl Context { n_tiles: i64, dim: i64, ) -> Result { - let mut s = Shape::new(); - for d in (0..self.nodes[a].shape.ndims()).rev() { - if d as i64 == dim { - s.sizes - .push((n_tiles as u32) * self.nodes[a].shape.sizes[d]); - } else { - s.sizes.push(self.nodes[a].shape.sizes[d]); - } - } + let mut s = self.nodes[a].shape.clone(); + s.sizes[dim as usize] *= n_tiles as u32; let node_id = self.nodes.insert(Node { callsite: callsite!(1), shape: s, @@ -866,8 +859,7 @@ impl Context { _ => return Err(ContextError::RealTypeError(dtype, callsite!(1))), } - let eps = self.scalar(1e-8, xla::ElementType::F32)?; - let eps = self.type_cast(eps, dtype); + let eps = self.scalar(1e-8, dtype)?; // prevent logarithm of zero let offset = self.add(prediction_probabilities, eps)?; let log = self.log(offset)?; From 0fe853bb5be138800e2a2d842e375b859f5cfd04 Mon Sep 17 00:00:00 2001 From: Eben Kadile Date: Mon, 25 Mar 2024 11:52:38 +0100 Subject: [PATCH 09/21] debugged reshape and reduce_sum derivatives --- src/core/graph/autodiff.rs | 67 +++++++++++++++++++++++++++----------- 1 file changed, 48 insertions(+), 19 deletions(-) diff --git a/src/core/graph/autodiff.rs b/src/core/graph/autodiff.rs index 23fb874..ad29b57 100644 --- a/src/core/graph/autodiff.rs +++ b/src/core/graph/autodiff.rs @@ -102,9 +102,40 @@ impl Context { Operation::Reshape(node) => { let next_pullback = self.diff(output, dependent_node)?; let node_sh = self.nodes[node].shape.clone(); - println!("Reshape {} \n{}", self.to_string(self.dependent_nodes[&dependent_node][0]), self.nodes[dependent_node].shape); - let pullback = self.reshape(next_pullback, node_sh)?; - dependent_pullbacks.push(pullback); + if self.nodes[next_pullback].shape.size() == node_sh.size() { + let pullback = self.reshape(next_pullback, node_sh)?; + println!("Reshape {}", self.nodes[next_pullback].shape); + dependent_pullbacks.push(pullback); + } else if let Some(s) = self.nodes[next_pullback] + .shape + .broadcast(&self.nodes[dependent_node].shape) + { + // this is the case where the gradients picked up some + // broadcasted batch dimensions along the way + let mut sum_pullback = next_pullback; + if self.nodes[dependent_node].shape.sizes.is_empty() { + for _ in 0..self.nodes[next_pullback].shape.ndims() { + sum_pullback = self.reduce_sum(next_pullback, 0, false)?; + } + println!("Reshape {}", self.nodes[sum_pullback].shape); + dependent_pullbacks.push(sum_pullback); + } else { + for d in 0..self.nodes[next_pullback].shape.ndims() { + if self.nodes[dependent_node].shape.sizes[d] == 1 { + sum_pullback = + self.reduce_sum(next_pullback, d as i64, true)? 
+ } + } + println!("Reshape {}", self.nodes[sum_pullback].shape); + dependent_pullbacks.push(sum_pullback); + } + } else { + return Err(ContextError::IncompatibleOperandShapes( + self.nodes[next_pullback].shape.clone(), + self.nodes[dependent_node].shape.clone(), + callsite!(1), + )); + } } Operation::Transpose(a, p) => { @@ -120,12 +151,13 @@ impl Context { Operation::ZerosLike(_) => continue, Operation::Add(a, b) => { - println!("Add"); + let next_pullback = self.diff(output, dependent_node)?; + println!("Add {}", self.nodes[next_pullback].shape); if a == with_respect_to { - dependent_pullbacks.push(self.diff(output, dependent_node)?); + dependent_pullbacks.push(next_pullback); } if b == with_respect_to { - dependent_pullbacks.push(self.diff(output, dependent_node)?); + dependent_pullbacks.push(next_pullback); } } @@ -299,13 +331,16 @@ impl Context { let n_tiles = self.nodes[node].shape.sizes[dim as usize] as i64; let mut new_shape = self.nodes[next_pullback].shape.clone(); - new_shape.sizes.insert(dim as usize, 1u32); - println!("got to reducesum"); + if new_shape.ndims() != self.nodes[node].shape.ndims() { + new_shape.sizes.insert(dim as usize, 1u32); + } let reshaped_pullback = self.reshape(next_pullback, new_shape.clone())?; + println!("{}", new_shape); let tiled_pullback = self.tile_in_dim(reshaped_pullback, n_tiles, dim)?; + println!("ReduceSum {}", self.nodes[tiled_pullback].shape); dependent_pullbacks.push(tiled_pullback); } @@ -313,25 +348,19 @@ impl Context { let next_pullback = self.diff(output, dependent_node)?; let n_tiles = self.nodes[node].shape.sizes[dim as usize] as i64; - let mut new_sizes = SmallVec::new(); - for i in (0..self.nodes[next_pullback].shape.ndims()).rev() { - new_sizes.push(self.nodes[next_pullback].shape.sizes[i]); - if i as i64 == dim { - new_sizes.push(1u32); - } - } - if self.nodes[next_pullback].shape.ndims() == 0 { - new_sizes.push(1u32); + let mut new_shape = self.nodes[next_pullback].shape.clone(); + if new_shape.ndims() != self.nodes[node].shape.ndims() { + new_shape.sizes.insert(dim as usize, 1u32); } - println!("got to reducemean"); let reshaped_pullback = - self.reshape(next_pullback, Shape { sizes: new_sizes })?; + self.reshape(next_pullback, new_shape)?; let tiled_pullback = self.tile_in_dim(reshaped_pullback, n_tiles, dim)?; let scale = self.scalar(1.0 / (n_tiles as f32), self.nodes[node].dtype)?; let rescaled_pullback = self.mul(scale, tiled_pullback)?; + println!("ReduceMean {}", self.nodes[rescaled_pullback].shape); dependent_pullbacks.push(rescaled_pullback); } From 5b964d3d8772a4ef5118ee5279e36a5fad906f58 Mon Sep 17 00:00:00 2001 From: Eben Kadile Date: Mon, 25 Mar 2024 13:26:46 +0100 Subject: [PATCH 10/21] shape checker failure somewhere --- examples/mnist_xla.rs | 7 ++++++- src/core/graph/autodiff.rs | 41 ++++++++++---------------------------- src/core/graph/context.rs | 3 +++ src/core/graph/math.rs | 16 +++++++++++++-- 4 files changed, 33 insertions(+), 34 deletions(-) diff --git a/examples/mnist_xla.rs b/examples/mnist_xla.rs index 962a459..ff5c89d 100644 --- a/examples/mnist_xla.rs +++ b/examples/mnist_xla.rs @@ -1,5 +1,6 @@ use std::fs::File; use std::io; +use std::time::Instant; use std::os::unix::fs::*; use unda::core::graph::*; use xla::{ElementType::*, PjRtClient, PjRtLoadedExecutable}; @@ -131,7 +132,7 @@ fn build_model_and_optimizer(client: &xla::PjRtClient) -> Result { let next_pullback = self.diff(output, dependent_node)?; - println!("Add {}", self.nodes[next_pullback].shape); if a == with_respect_to { 
dependent_pullbacks.push(next_pullback); } @@ -173,7 +169,6 @@ impl Context { Operation::Mul(a, b) => { let next_pullback = self.diff(output, dependent_node)?; - println!("Mul {} {}", self.nodes[a].shape, self.nodes[b].shape); if a == b && a == with_respect_to { let two = self.scalar(2, wrt_dtype)?; let mul = self.mul(two, a)?; @@ -189,33 +184,24 @@ impl Context { Operation::MatMul(a, b) => { let next_pullback = self.diff(output, dependent_node)?; - println!("got to matmul"); if a == with_respect_to { - let mut transpose_dims = self.nodes[a] - .shape - .sizes - .iter() - .map(|s| *s as i64) - .collect::>(); + let mut transpose_dims = + Vec::from_iter(0..(self.nodes[b].shape.ndims() as i64)); let ndims = transpose_dims.len(); transpose_dims.swap(ndims - 2, ndims - 1); - let transpose = self.transpose(a, transpose_dims.as_slice())?; - let this_pullback = self.matmul(transpose, next_pullback)?; + let transpose = self.transpose(b, &transpose_dims)?; + let this_pullback = self.matmul(next_pullback, transpose)?; dependent_pullbacks.push(this_pullback); } else if b == with_respect_to { - let mut transpose_dims = self.nodes[b] - .shape - .sizes - .iter() - .map(|s| *s as i64) - .collect::>(); + let mut transpose_dims = + Vec::from_iter(0..(self.nodes[a].shape.ndims() as i64)); let ndims = transpose_dims.len(); transpose_dims.swap(ndims - 2, ndims - 1); - let transpose = self.transpose(b, &transpose_dims)?; - let this_pullback = self.matmul(next_pullback, transpose)?; + let transpose = self.transpose(a, &transpose_dims)?; + let this_pullback = self.matmul(transpose, next_pullback)?; dependent_pullbacks.push(this_pullback); } } @@ -261,20 +247,17 @@ impl Context { let one = self.scalar(1, wrt_dtype)?; let quotient = self.div(one, a)?; - println!("Log"); let next_pullback = self.mul(quotient, next_pullback)?; dependent_pullbacks.push(next_pullback); } Operation::Neg(_) => { - println!("Neg"); let next_pullback = self.diff(output, dependent_node)?; dependent_pullbacks.push(self.neg(next_pullback)); } Operation::Exp(_) => { let next_pullback = self.diff(output, dependent_node)?; - println!("Exp"); let this_pullback = self.mul(next_pullback, dependent_node)?; dependent_pullbacks.push(this_pullback); @@ -336,11 +319,9 @@ impl Context { } let reshaped_pullback = self.reshape(next_pullback, new_shape.clone())?; - println!("{}", new_shape); let tiled_pullback = self.tile_in_dim(reshaped_pullback, n_tiles, dim)?; - println!("ReduceSum {}", self.nodes[tiled_pullback].shape); dependent_pullbacks.push(tiled_pullback); } @@ -352,15 +333,13 @@ impl Context { if new_shape.ndims() != self.nodes[node].shape.ndims() { new_shape.sizes.insert(dim as usize, 1u32); } - let reshaped_pullback = - self.reshape(next_pullback, new_shape)?; + let reshaped_pullback = self.reshape(next_pullback, new_shape)?; let tiled_pullback = self.tile_in_dim(reshaped_pullback, n_tiles, dim)?; let scale = self.scalar(1.0 / (n_tiles as f32), self.nodes[node].dtype)?; let rescaled_pullback = self.mul(scale, tiled_pullback)?; - println!("ReduceMean {}", self.nodes[rescaled_pullback].shape); dependent_pullbacks.push(rescaled_pullback); } diff --git a/src/core/graph/context.rs b/src/core/graph/context.rs index 9860a9e..f0cd559 100644 --- a/src/core/graph/context.rs +++ b/src/core/graph/context.rs @@ -62,6 +62,9 @@ pub enum ContextError { #[error("Expected tensor of rank {0}, got {1}")] RankError(usize, usize, Callsite), + + #[error("Invalid permutation passed to transpose. 
Expected permutation of length {0}, got {1}")] + TransposeLenError(usize, usize, Callsite), } pub type Result = std::result::Result; diff --git a/src/core/graph/math.rs b/src/core/graph/math.rs index 18ab433..fc0b5d7 100644 --- a/src/core/graph/math.rs +++ b/src/core/graph/math.rs @@ -584,11 +584,23 @@ impl Context { } pub fn transpose(&mut self, a: NodeIdentifier, index_perm: &[i64]) -> Result { - let a_shape = self.nodes[a].shape.clone(); + if index_perm.len() != self.nodes[a].shape.ndims() { + return Err(ContextError::TransposeLenError( + self.nodes[a].shape.ndims(), + index_perm.len(), + callsite!(1), + )); + } + let mut new_shape = Shape::new(); + for d in 0..index_perm.len() { + new_shape + .sizes + .push(self.nodes[a].shape.sizes[index_perm[d] as usize]); + } let index_perms_deref = index_perm.to_vec(); let node_id = self.nodes.insert(Node { callsite: callsite!(1), - shape: a_shape, + shape: new_shape, operation: Operation::Transpose(a, index_perms_deref), dtype: self.nodes[a].dtype, }); From d1a7b625c113b38e8ec4a79dd8545d1dce62409b Mon Sep 17 00:00:00 2001 From: Eben Kadile Date: Mon, 25 Mar 2024 15:51:43 +0100 Subject: [PATCH 11/21] more reshaping problems --- src/core/graph/autodiff.rs | 12 +++----- src/core/graph/compile.rs | 34 +++++++++-------------- src/core/graph/context.rs | 3 ++ src/core/graph/dtypes.rs | 37 +++++++++++++++++++++++++ src/core/graph/math.rs | 57 ++++++++++++-------------------------- src/core/graph/mod.rs | 1 + 6 files changed, 75 insertions(+), 69 deletions(-) create mode 100644 src/core/graph/dtypes.rs diff --git a/src/core/graph/autodiff.rs b/src/core/graph/autodiff.rs index f8b4013..003b927 100644 --- a/src/core/graph/autodiff.rs +++ b/src/core/graph/autodiff.rs @@ -263,19 +263,15 @@ impl Context { dependent_pullbacks.push(this_pullback); } + // TODO: Test this! 
Operation::TileInDim { node, n_tiles, dim } => { let next_pullback = self.diff(output, dependent_node)?; - let mut new_sizes = SmallVec::new(); - for i in (0..self.nodes[node].shape.ndims()).rev() { - new_sizes.push(self.nodes[node].shape.sizes[i]); - if i as i64 == dim { - new_sizes.push(n_tiles as u32); - } - } + let mut new_shape = self.nodes[node].shape.clone(); + new_shape.sizes.insert(dim as usize, n_tiles as u32); let reshaped_pullback = - self.reshape(next_pullback, Shape { sizes: new_sizes })?; + self.reshape(next_pullback, new_shape)?; dependent_pullbacks.push(self.reduce_sum( reshaped_pullback, dim, diff --git a/src/core/graph/compile.rs b/src/core/graph/compile.rs index 7194e8e..2c2d9cf 100644 --- a/src/core/graph/compile.rs +++ b/src/core/graph/compile.rs @@ -31,7 +31,6 @@ impl Context { Err(CompileError::NoReturn)?; } - for a in returns.iter() { self.fold_consts(*a, usize::MAX)?; } @@ -185,11 +184,10 @@ impl Context { } Operation::Transpose(a, perm_index) => { - if unda_xla_map.contains_key(&a) + if unda_xla_map.contains_key(&a) && xla_op_slotmap.contains_key(unda_xla_map[&a]) { - let xla_op = xla_op_slotmap[unda_xla_map[&a]] - .transpose(&perm_index)?; + let xla_op = xla_op_slotmap[unda_xla_map[&a]].transpose(&perm_index)?; let xla_id = xla_op_slotmap.insert(xla_op); unda_xla_map.insert(*dependent_op, xla_id); unda_op_queue.push_back(*dependent_op); @@ -239,7 +237,6 @@ impl Context { } } - Operation::Exp(a) => { if unda_xla_map.contains_key(&a) && xla_op_slotmap.contains_key(unda_xla_map[&a]) @@ -389,6 +386,13 @@ impl Context { if unda_xla_map.contains_key(&node) && xla_op_slotmap.contains_key(unda_xla_map[&node]) { + println!( + "{} {} {} {}", + self.nodes[node].shape, + this_node.shape, + self.nodes[node].operation, + self.nodes[node].callsite + ); let xla_op = xla_op_slotmap[unda_xla_map[&node]].reshape( this_node .shape @@ -463,10 +467,7 @@ impl Context { covered_ops.insert(*dependent_op); } } - Operation::ReduceMax { - node, - dim, - } => { + Operation::ReduceMax { node, dim } => { if xla_op_slotmap.contains_key(unda_xla_map[&node]) { let xla_op = xla_op_slotmap[unda_xla_map[&node]].reduce_max(&[dim], false)?; @@ -476,10 +477,7 @@ impl Context { covered_ops.insert(*dependent_op); } } - Operation::ReduceSum { - node, - dim, - } => { + Operation::ReduceSum { node, dim } => { if xla_op_slotmap.contains_key(unda_xla_map[&node]) { let xla_op = xla_op_slotmap[unda_xla_map[&node]].reduce_sum(&[dim], false)?; @@ -489,10 +487,7 @@ impl Context { covered_ops.insert(*dependent_op); } } - Operation::ReduceMean { - node, - dim, - } => { + Operation::ReduceMean { node, dim } => { if xla_op_slotmap.contains_key(unda_xla_map[&node]) { let xla_op = xla_op_slotmap[unda_xla_map[&node]].reduce_mean(&[dim], false)?; @@ -502,10 +497,7 @@ impl Context { covered_ops.insert(*dependent_op); } } - Operation::ReduceArgmax { - node, - dim, - } => { + Operation::ReduceArgmax { node, dim } => { if xla_op_slotmap.contains_key(unda_xla_map[&node]) { let xla_op = xla_op_slotmap[unda_xla_map[&node]].reduce_argmax(dim, false)?; diff --git a/src/core/graph/context.rs b/src/core/graph/context.rs index f0cd559..a3211ff 100644 --- a/src/core/graph/context.rs +++ b/src/core/graph/context.rs @@ -60,6 +60,9 @@ pub enum ContextError { #[error("Expected real type, got {0}. Real types are F16, Bf16, F32, F64")] RealTypeError(xla::ElementType, Callsite), + #[error("Expected floating point type, got {0}. 
Real types are F16, Bf16, F32, F64, C64, C128")] + FPTypeError(xla::ElementType, Callsite), + #[error("Expected tensor of rank {0}, got {1}")] RankError(usize, usize, Callsite), diff --git a/src/core/graph/dtypes.rs b/src/core/graph/dtypes.rs new file mode 100644 index 0000000..ced980f --- /dev/null +++ b/src/core/graph/dtypes.rs @@ -0,0 +1,37 @@ +use super::*; + +pub(crate) fn check_fp_type(dtype: xla::ElementType) -> Result<xla::ElementType> { + match dtype { + xla::ElementType::F16 + | xla::ElementType::Bf16 + | xla::ElementType::F32 + | xla::ElementType::F64 + | xla::ElementType::C64 + | xla::ElementType::C128 => Ok(dtype), + _ => Err(ContextError::FPTypeError(dtype, callsite!(1))), + } +} + +pub(crate) fn check_int_type(dtype: xla::ElementType) -> Result<xla::ElementType> { + match dtype { + xla::ElementType::U8 + | xla::ElementType::S8 + | xla::ElementType::U16 + | xla::ElementType::S16 + | xla::ElementType::U32 + | xla::ElementType::S32 + | xla::ElementType::U64 + | xla::ElementType::S64 => Ok(dtype), + _ => Err(ContextError::IntegralTypeError(dtype, callsite!(1))), + } +} + +pub(crate) fn check_real_type(dtype: xla::ElementType) -> Result<xla::ElementType> { + match dtype { + xla::ElementType::F16 + | xla::ElementType::Bf16 + | xla::ElementType::F32 + | xla::ElementType::F64 => Ok(dtype), + _ => Err(ContextError::RealTypeError(dtype, callsite!(1))), + } +} \ No newline at end of file diff --git a/src/core/graph/math.rs b/src/core/graph/math.rs index fc0b5d7..6ee3491 100644 --- a/src/core/graph/math.rs +++ b/src/core/graph/math.rs @@ -1,5 +1,7 @@ use smallvec::SmallVec; +use self::dtypes::*; + use super::*; impl Context { @@ -517,12 +519,17 @@ impl Context { } pub fn softmax(&mut self, a: NodeIdentifier) -> Result<NodeIdentifier> { + let dtype = check_fp_type(self.nodes[a].dtype)?; + let max = self.reduce_max(a, 0, true)?; let stop_grad = self.stop_gradient(max); let unnormalized = self.sub(a, stop_grad)?; let unnormalized_exp = self.exp(unnormalized)?; let sum = self.reduce_sum(unnormalized_exp, 0, true)?; + let eps = self.scalar(1e-8, dtype)?; + // prevent division by 0 + let sum = self.add(sum, eps)?; self.div(unnormalized_exp, sum) } @@ -553,6 +560,7 @@ impl Context { pub fn reshape(&mut self, a: NodeIdentifier, shape: Shape) -> Result<NodeIdentifier> { let a_size = self.nodes[a].shape.size(); + println!("reshape: {} {}", self.nodes[a].shape, shape); if a_size != shape.size() { Err(ContextError::ShapeConversion( ShapeConversionError::MismatchedSizes( @@ -783,21 +791,9 @@ impl Context { dense_predictions: NodeIdentifier, sparse_label_vector: NodeIdentifier, ) -> Result<NodeIdentifier> { - let converted_labels = match self.nodes[sparse_label_vector].dtype { - xla::ElementType::S64 => sparse_label_vector, - xla::ElementType::U8 - | xla::ElementType::S8 - | xla::ElementType::U16 - | xla::ElementType::S16 - | xla::ElementType::U32 - | xla::ElementType::S32 - | xla::ElementType::U64 => self.type_cast(sparse_label_vector, xla::ElementType::S64), - _ => { - return Err(ContextError::IntegralTypeError( - self.nodes[sparse_label_vector].dtype, - callsite!(1), - )) - } + let converted_labels = match check_int_type(self.nodes[sparse_label_vector].dtype) { + Ok(_) => self.type_cast(sparse_label_vector, xla::ElementType::S64), + _ => unreachable!(), }; let sparse_predictions = self.reduce_argmax(dense_predictions, 1, false)?; let compare = self.eq(sparse_predictions, converted_labels)?; @@ -820,31 +816,19 @@ impl Context { } let label_len = self.nodes[sparse_label_vector].shape.sizes[0]; - let converted = match self.nodes[sparse_label_vector].dtype { - xla::ElementType::S64 =>
sparse_label_vector, - xla::ElementType::U8 - | xla::ElementType::S8 - | xla::ElementType::U16 - | xla::ElementType::S16 - | xla::ElementType::U32 - | xla::ElementType::S32 - | xla::ElementType::U64 => self.type_cast(sparse_label_vector, xla::ElementType::S64), - _ => { - return Err(ContextError::IntegralTypeError( - self.nodes[sparse_label_vector].dtype, - callsite!(1), - )) - } + let converted_labels = match check_int_type(self.nodes[sparse_label_vector].dtype) { + Ok(_) => self.type_cast(sparse_label_vector, xla::ElementType::S64), + _ => unreachable!(), }; let node_id = self.nodes.insert(Node { callsite: callsite!(1), shape: Shape::from([label_len, n_classes as u32]), - operation: Operation::OneHot(converted), + operation: Operation::OneHot(converted_labels), dtype: dtype, }); self.dependent_nodes - .entry(converted) + .entry(converted_labels) .or_insert(Vec::new()) .push(node_id); Ok(node_id) @@ -855,7 +839,7 @@ impl Context { prediction_probabilities: NodeIdentifier, one_hot_labels: NodeIdentifier, ) -> Result { - let dtype = self.nodes[prediction_probabilities].dtype; + let dtype = check_real_type(self.nodes[prediction_probabilities].dtype)?; if dtype != self.nodes[one_hot_labels].dtype { return Err(ContextError::IncompatibleOperandTypes( dtype, @@ -863,13 +847,6 @@ impl Context { callsite!(1), )); } - match dtype { - xla::ElementType::F16 - | xla::ElementType::Bf16 - | xla::ElementType::F32 - | xla::ElementType::F64 => {} - _ => return Err(ContextError::RealTypeError(dtype, callsite!(1))), - } let eps = self.scalar(1e-8, dtype)?; // prevent logarithm of zero diff --git a/src/core/graph/mod.rs b/src/core/graph/mod.rs index ad3420e..e4ddeae 100644 --- a/src/core/graph/mod.rs +++ b/src/core/graph/mod.rs @@ -4,6 +4,7 @@ mod compile; mod constant; mod consteval; mod context; +mod dtypes; mod logic; mod math; mod node; From 7696e548f89f1978fd122c13fa92b5c1eda758f9 Mon Sep 17 00:00:00 2001 From: Eben Kadile Date: Mon, 25 Mar 2024 15:56:05 +0100 Subject: [PATCH 12/21] minor logic issues in math.rs --- src/core/graph/math.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/core/graph/math.rs b/src/core/graph/math.rs index 6ee3491..c58ae09 100644 --- a/src/core/graph/math.rs +++ b/src/core/graph/math.rs @@ -747,19 +747,19 @@ impl Context { dim: i64, keepdims: bool, ) -> Result { - let mut s = Shape::new(); - for d in (0..self.nodes[a].shape.ndims()).rev() { - if d as i64 != dim { - s.sizes.push(self.nodes[a].shape.sizes[d]) - } - } + let dtype = check_fp_type(self.nodes[a].dtype)?; + + let mut s = self.nodes[a].shape.clone(); + s.sizes.remove(dim as usize); + let node_id = self.nodes.insert(Node { callsite: callsite!(1), shape: s, operation: Operation::ReduceMean { node: a, dim }, - dtype: self.nodes[a].dtype, + dtype: dtype, }); self.dependent_nodes.entry(a).or_default().push(node_id); + self.maybe_keepdims(node_id, dim, keepdims) } @@ -793,7 +793,7 @@ impl Context { ) -> Result { let converted_labels = match check_int_type(self.nodes[sparse_label_vector].dtype) { Ok(_) => self.type_cast(sparse_label_vector, xla::ElementType::S64), - _ => unreachable!(), + Err(e) => return Err(e), }; let sparse_predictions = self.reduce_argmax(dense_predictions, 1, false)?; let compare = self.eq(sparse_predictions, converted_labels)?; From 9ec6d2fbc922a78629e0006071a707c6abdcdb48 Mon Sep 17 00:00:00 2001 From: Eben Kadile Date: Tue, 26 Mar 2024 11:20:57 +0100 Subject: [PATCH 13/21] cannot concat scalars into vectors --- examples/mnist_xla.rs | 2 +- 
src/core/graph/compile.rs | 14 ------ src/core/graph/consteval.rs | 74 +++++++++++++++----------------- src/core/graph/math.rs | 85 ++++++++++++++++++++++++++++--------- 4 files changed, 101 insertions(+), 74 deletions(-) diff --git a/examples/mnist_xla.rs b/examples/mnist_xla.rs index ff5c89d..34aa86d 100644 --- a/examples/mnist_xla.rs +++ b/examples/mnist_xla.rs @@ -132,7 +132,7 @@ fn build_model_and_optimizer(client: &xla::PjRtClient) -> Result = SlotMap::with_key(); @@ -386,13 +379,6 @@ impl Context { if unda_xla_map.contains_key(&node) && xla_op_slotmap.contains_key(unda_xla_map[&node]) { - println!( - "{} {} {} {}", - self.nodes[node].shape, - this_node.shape, - self.nodes[node].operation, - self.nodes[node].callsite - ); let xla_op = xla_op_slotmap[unda_xla_map[&node]].reshape( this_node .shape diff --git a/src/core/graph/consteval.rs b/src/core/graph/consteval.rs index 0b321a2..39a8d2c 100644 --- a/src/core/graph/consteval.rs +++ b/src/core/graph/consteval.rs @@ -339,6 +339,33 @@ impl Context { Ok(changed) } + // Utility function for handling products with tiled and reshaped constants + // TODO: This should be made to handle arbitrary depth + fn replace_tiled_const(&mut self, a: NodeIdentifier, b: NodeIdentifier, top_level_node: NodeIdentifier) -> Result { + if let Operation::TileInDim { node, n_tiles: _, dim: _ } = self.nodes[a].operation { + if self.nodes[node].is_one()? { + println!("found tile {} {}", self.nodes[b].shape, self.nodes[top_level_node].shape); + let tiled_b = self.tile_to_shape(b, self.nodes[top_level_node].shape.clone())?; + self.replace_index(top_level_node, tiled_b)?; + Ok(true) + } else { + self.replace_tiled_const(node, b, top_level_node) + } + } else if let Operation::Reshape(node) = self.nodes[a].operation { + if self.nodes[node].is_one()? { + println!("found reshape {} {}", self.nodes[b].shape, self.nodes[top_level_node].shape); + let tiled_b = self.tile_to_shape(b, self.nodes[top_level_node].shape.clone())?; + self.replace_index(top_level_node, tiled_b)?; + Ok(true) + } else { + self.replace_tiled_const(node, b, top_level_node) + } + } + else { + Ok(false) + } + } + /// Folds constants in place by replacing any node whose both inputs are Constant /// with a Constant of the result of the operation. All existing references to /// the old node will still point to it once its replaced, and this process is @@ -396,26 +423,6 @@ impl Context { modifications += 1; changed = true; } - // TODO: Clean this up! Too many cases!! - if let Operation::TileInDim { node, n_tiles: _, dim: _ } = self.nodes[a].operation { - if self.nodes[node].is_one()? { - self.replace_index(node_id, b)?; - modifications += 1; - changed = true; - } else if let Operation::Reshape(x) = self.nodes[node].operation { - if self.nodes[x].is_one()? { - self.replace_index(node_id, b)?; - modifications += 1; - changed = true; - } else if let Operation::Reshape(y) = self.nodes[x].operation { - if self.nodes[y].is_one()? { - self.replace_index(node_id, b)?; - modifications += 1; - changed = true; - } - } - } - } if self.nodes[b].is_zero()? { self.replace_index(node_id, b)?; modifications += 1; @@ -426,25 +433,14 @@ impl Context { modifications += 1; changed = true; } - if let Operation::TileInDim { node, n_tiles: _, dim: _ } = self.nodes[b].operation { - if self.nodes[node].is_one()? { - self.replace_index(node_id, a)?; - modifications += 1; - changed = true; - } else if let Operation::Reshape(x) = self.nodes[node].operation { - if self.nodes[x].is_one()? 
{ - self.replace_index(node_id, a)?; - modifications += 1; - changed = true; - } else if let Operation::Reshape(y) = self.nodes[x].operation { - if self.nodes[y].is_one()? { - self.replace_index(node_id, a)?; - modifications += 1; - changed = true; - } - } - } - } + if self.replace_tiled_const(a, b, node_id)? { + modifications += 1; + changed = true; + }; + if self.replace_tiled_const(b, a, node_id)? { + modifications += 1; + changed = true; + }; if self.nodes[a].is_const().is_none() { to_visit.push(a); } diff --git a/src/core/graph/math.rs b/src/core/graph/math.rs index c58ae09..aea5849 100644 --- a/src/core/graph/math.rs +++ b/src/core/graph/math.rs @@ -560,7 +560,6 @@ impl Context { pub fn reshape(&mut self, a: NodeIdentifier, shape: Shape) -> Result { let a_size = self.nodes[a].shape.size(); - println!("reshape: {} {}", self.nodes[a].shape, shape); if a_size != shape.size() { Err(ContextError::ShapeConversion( ShapeConversionError::MismatchedSizes( @@ -625,14 +624,9 @@ impl Context { stride: i64, dim: i64, ) -> Result { - let mut s = Shape::new(); - for d in (0..self.nodes[a].shape.ndims()).rev() { - if d as i64 == dim { - s.sizes.push(((stop - start) / stride) as u32); - } else { - s.sizes.push(self.nodes[a].shape.sizes[d]); - } - } + let mut s = self.nodes[a].shape.clone(); + s.sizes[dim as usize] = ((start - stop) / stride) as u32; + let node_id = self.nodes.insert(Node { callsite: callsite!(1), shape: s, @@ -656,7 +650,12 @@ impl Context { dim: i64, ) -> Result { let mut s = self.nodes[a].shape.clone(); - s.sizes[dim as usize] *= n_tiles as u32; + if s.sizes.is_empty() { + s.sizes.push(n_tiles as u32); + } else { + s.sizes[dim as usize] *= n_tiles as u32; + } + let node_id = self.nodes.insert(Node { callsite: callsite!(1), shape: s, @@ -671,6 +670,48 @@ impl Context { Ok(node_id) } + // Utility function for tiling a small tensor to a larger shape + // when it is known that the smaller tensor's shape broadcasts to the larger shape + pub fn tile_to_shape(&mut self, a: NodeIdentifier, shape: Shape) -> Result { + let node_a_shape = self.nodes[a].shape.clone(); + + if node_a_shape == shape { + return Ok(a); + } + + match node_a_shape.broadcast(&shape) { + None => Err(ContextError::IncompatibleOperandShapes( + node_a_shape, + shape.clone(), + callsite!(1), + )), + Some(s) => { + if s.size() > shape.size() { + return Err(ContextError::IncompatibleOperandShapes( + node_a_shape, + shape.clone(), + callsite!(1), + )); + } + if node_a_shape.sizes.is_empty() { + let mut tiled = a; + for d in (0..s.ndims()).rev() { + tiled = self.tile_in_dim(tiled, s.sizes[d] as i64, 0)?; + } + Ok(tiled) + } else { + let mut tiled = a; + for d in 0..s.ndims() { + if node_a_shape.sizes[d] == 1 { + tiled = self.tile_in_dim(tiled, s.sizes[d] as i64, d as i64)?; + } + } + Ok(tiled) + } + } + } + } + pub fn zeros_like(&mut self, a: NodeIdentifier) -> NodeIdentifier { let node_id = self.nodes.insert(Node { callsite: callsite!(1), @@ -703,12 +744,12 @@ impl Context { dim: i64, keepdims: bool, ) -> Result { - let mut s = Shape::new(); - for d in (0..self.nodes[a].shape.ndims()).rev() { - if d as i64 != dim { - s.sizes.push(self.nodes[a].shape.sizes[d]) - } + let mut s = self.nodes[a].shape.clone(); + if s.sizes.is_empty() { + return Ok(a) } + s.sizes.remove(dim as usize); + let node_id = self.nodes.insert(Node { callsite: callsite!(1), shape: s, @@ -725,12 +766,12 @@ impl Context { dim: i64, keepdims: bool, ) -> Result { - let mut s = Shape::new(); - for d in (0..self.nodes[a].shape.ndims()).rev() { - if d as i64 != 
dim { - s.sizes.push(self.nodes[a].shape.sizes[d]) - } + let mut s = self.nodes[a].shape.clone(); + if s.sizes.is_empty() { + return Ok(a) } + s.sizes.remove(dim as usize); + let node_id = self.nodes.insert(Node { callsite: callsite!(1), shape: s, @@ -750,6 +791,9 @@ impl Context { let dtype = check_fp_type(self.nodes[a].dtype)?; let mut s = self.nodes[a].shape.clone(); + if s.sizes.is_empty() { + return Ok(a) + } s.sizes.remove(dim as usize); let node_id = self.nodes.insert(Node { @@ -771,6 +815,7 @@ impl Context { ) -> Result { let mut s = self.nodes[a].shape.clone(); s.sizes.remove(dim as usize); + let node_id = self.nodes.insert(Node { callsite: callsite!(1), shape: s, From b4203f2554aae3ec83fd0759a1355baddd561112 Mon Sep 17 00:00:00 2001 From: Eben Kadile Date: Tue, 26 Mar 2024 11:34:23 +0100 Subject: [PATCH 14/21] fixed tiling scalars --- examples/mnist_xla.rs | 6 ++++-- src/core/graph/consteval.rs | 2 -- src/core/graph/math.rs | 24 +++++++++++++++--------- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/examples/mnist_xla.rs b/examples/mnist_xla.rs index 34aa86d..5b9fac6 100644 --- a/examples/mnist_xla.rs +++ b/examples/mnist_xla.rs @@ -244,12 +244,14 @@ fn main() { let test_labels = File::open(test_lbl_path).expect("Failed to open training label file"); println!("Building model and optimizer . . ."); + let now = Instant::now(); let executable = build_model_and_optimizer(&client).expect("Failed to build model and optimizer"); + println!("Finished build in {:.2?}", now.elapsed()); let (mut w1, mut b1, mut w2, mut b2, mut w3, mut b3, mut w_out, mut b_out) = init_params(); - println!("Beginning training."); + println!("Training model . . ."); let now = Instant::now(); for epoch in 0..EPOCHS { let mut train_accuracy = 0f32; @@ -320,7 +322,7 @@ fn main() { train_accuracy / 600f32 ); } - println!("Finished training after {:.2?}", now.elapsed()); + println!("Finished training in {:.2?}", now.elapsed()); // ABSTRACT API REQUIREMENT 7: Serialization // The model is not worth very much if it disappears after our training loop. diff --git a/src/core/graph/consteval.rs b/src/core/graph/consteval.rs index 39a8d2c..291ff90 100644 --- a/src/core/graph/consteval.rs +++ b/src/core/graph/consteval.rs @@ -344,7 +344,6 @@ impl Context { fn replace_tiled_const(&mut self, a: NodeIdentifier, b: NodeIdentifier, top_level_node: NodeIdentifier) -> Result { if let Operation::TileInDim { node, n_tiles: _, dim: _ } = self.nodes[a].operation { if self.nodes[node].is_one()? { - println!("found tile {} {}", self.nodes[b].shape, self.nodes[top_level_node].shape); let tiled_b = self.tile_to_shape(b, self.nodes[top_level_node].shape.clone())?; self.replace_index(top_level_node, tiled_b)?; Ok(true) @@ -353,7 +352,6 @@ impl Context { } } else if let Operation::Reshape(node) = self.nodes[a].operation { if self.nodes[node].is_one()? 
{ - println!("found reshape {} {}", self.nodes[b].shape, self.nodes[top_level_node].shape); let tiled_b = self.tile_to_shape(b, self.nodes[top_level_node].shape.clone())?; self.replace_index(top_level_node, tiled_b)?; Ok(true) diff --git a/src/core/graph/math.rs b/src/core/graph/math.rs index aea5849..cbc3b5e 100644 --- a/src/core/graph/math.rs +++ b/src/core/graph/math.rs @@ -558,7 +558,12 @@ impl Context { node_id } - pub fn reshape(&mut self, a: NodeIdentifier, shape: Shape) -> Result { + pub fn reshape>( + &mut self, + a: NodeIdentifier, + shape: S, + ) -> Result { + let shape = shape.into(); let a_size = self.nodes[a].shape.size(); if a_size != shape.size() { Err(ContextError::ShapeConversion( @@ -650,6 +655,11 @@ impl Context { dim: i64, ) -> Result { let mut s = self.nodes[a].shape.clone(); + let node = if s.sizes.is_empty() { + self.reshape(a, [1])? + } else { + a + }; if s.sizes.is_empty() { s.sizes.push(n_tiles as u32); } else { @@ -659,11 +669,7 @@ impl Context { let node_id = self.nodes.insert(Node { callsite: callsite!(1), shape: s, - operation: Operation::TileInDim { - node: a, - n_tiles, - dim, - }, + operation: Operation::TileInDim { node, n_tiles, dim }, dtype: self.nodes[a].dtype, }); self.dependent_nodes.entry(a).or_default().push(node_id); @@ -746,7 +752,7 @@ impl Context { ) -> Result { let mut s = self.nodes[a].shape.clone(); if s.sizes.is_empty() { - return Ok(a) + return Ok(a); } s.sizes.remove(dim as usize); @@ -768,7 +774,7 @@ impl Context { ) -> Result { let mut s = self.nodes[a].shape.clone(); if s.sizes.is_empty() { - return Ok(a) + return Ok(a); } s.sizes.remove(dim as usize); @@ -792,7 +798,7 @@ impl Context { let mut s = self.nodes[a].shape.clone(); if s.sizes.is_empty() { - return Ok(a) + return Ok(a); } s.sizes.remove(dim as usize); From 3fb9497b06334a4f3e229421fd258213bdc2cef3 Mon Sep 17 00:00:00 2001 From: Eben Kadile Date: Tue, 26 Mar 2024 12:13:33 +0100 Subject: [PATCH 15/21] adapted autodiff to handle broadcasted operands --- examples/mnist_xla.rs | 4 ++-- src/core/graph/autodiff.rs | 47 +++++++++++++++++++++++++++---------- src/core/graph/context.rs | 3 +++ src/core/graph/math.rs | 48 ++++++++++++++++++++++++++++++++++++-- 4 files changed, 86 insertions(+), 16 deletions(-) diff --git a/examples/mnist_xla.rs b/examples/mnist_xla.rs index 5b9fac6..2ac8d66 100644 --- a/examples/mnist_xla.rs +++ b/examples/mnist_xla.rs @@ -213,8 +213,8 @@ fn load_mnist_batch( Err(_) => panic!("Failed to reshape MNIST image batch!"), }; - let mut label_bytes = [0u8; 100 * 4]; - labels.read_exact_at(&mut label_bytes, 8 + 100 * 4 * batch_idx)?; + let mut label_bytes = [0u8; 100]; + labels.read_exact_at(&mut label_bytes, 8 + 100 * batch_idx)?; let labels_xla = xla::Literal::vec1(&label_bytes); Ok((images_xla, labels_xla)) diff --git a/src/core/graph/autodiff.rs b/src/core/graph/autodiff.rs index 003b927..807ce4f 100644 --- a/src/core/graph/autodiff.rs +++ b/src/core/graph/autodiff.rs @@ -102,6 +102,9 @@ impl Context { Operation::Reshape(node) => { let next_pullback = self.diff(output, dependent_node)?; let node_sh = self.nodes[node].shape.clone(); + let pullback = self.reshape(next_pullback, node_sh)?; + dependent_pullbacks.push(pullback); + /* if self.nodes[next_pullback].shape.size() == node_sh.size() { let pullback = self.reshape(next_pullback, node_sh)?; dependent_pullbacks.push(pullback); @@ -133,6 +136,7 @@ impl Context { callsite!(1), )); } + */ } Operation::Transpose(a, p) => { @@ -150,19 +154,27 @@ impl Context { Operation::Add(a, b) => { let next_pullback = 
self.diff(output, dependent_node)?; if a == with_respect_to { + let next_pullback = + self.sum_to_shape(next_pullback, self.nodes[a].shape.clone())?; dependent_pullbacks.push(next_pullback); } if b == with_respect_to { + let next_pullback = + self.sum_to_shape(next_pullback, self.nodes[b].shape.clone())?; dependent_pullbacks.push(next_pullback); } } Operation::Sub(a, b) => { + let next_pullback = self.diff(output, dependent_node)?; if a == with_respect_to { - dependent_pullbacks.push(self.diff(output, dependent_node)?); + let next_pullback = + self.sum_to_shape(next_pullback, self.nodes[a].shape.clone())?; + dependent_pullbacks.push(next_pullback); } if b == with_respect_to { - let next_pullback = self.diff(output, dependent_node)?; + let next_pullback = + self.sum_to_shape(next_pullback, self.nodes[b].shape.clone())?; dependent_pullbacks.push(self.neg(next_pullback)); } } @@ -171,14 +183,21 @@ impl Context { let next_pullback = self.diff(output, dependent_node)?; if a == b && a == with_respect_to { let two = self.scalar(2, wrt_dtype)?; - let mul = self.mul(two, a)?; - dependent_pullbacks.push(self.mul(mul, next_pullback)?); + let mul2 = self.mul(two, a)?; + let this_pullback = self.mul(mul2, next_pullback)?; + let this_pullback = + self.sum_to_shape(this_pullback, self.nodes[a].shape.clone())?; + dependent_pullbacks.push(this_pullback); } else if a == with_respect_to { - let mul = self.mul(next_pullback, a)?; - dependent_pullbacks.push(mul); + let this_pullback = self.mul(next_pullback, b)?; + let this_pullback = + self.sum_to_shape(this_pullback, self.nodes[a].shape.clone())?; + dependent_pullbacks.push(this_pullback); } else if b == with_respect_to { - let mul = self.mul(a, next_pullback)?; - dependent_pullbacks.push(mul); + let this_pullback = self.mul(next_pullback, a)?; + let this_pullback = + self.sum_to_shape(this_pullback, self.nodes[b].shape.clone())?; + dependent_pullbacks.push(this_pullback); } } @@ -209,18 +228,23 @@ impl Context { Operation::Div(a, b) => { let next_pullback = self.diff(output, dependent_node)?; if a == with_respect_to { - let div = self.div(next_pullback, b)?; - dependent_pullbacks.push(div); + let this_pullback = self.div(next_pullback, b)?; + let this_pullback = + self.sum_to_shape(this_pullback, self.nodes[a].shape.clone())?; + dependent_pullbacks.push(this_pullback); } if b == with_respect_to { let mul = self.mul(b, b)?; let div = self.div(a, mul)?; let neg = self.neg(div); let this_pullback = self.mul(neg, next_pullback)?; + let this_pullback = + self.sum_to_shape(this_pullback, self.nodes[b].shape.clone())?; dependent_pullbacks.push(this_pullback); } } + // TODO: handle potentially broadcasted operands here Operation::Pow(a, b) => { let next_pullback = self.diff(output, dependent_node)?; if a == with_respect_to { @@ -270,8 +294,7 @@ impl Context { let mut new_shape = self.nodes[node].shape.clone(); new_shape.sizes.insert(dim as usize, n_tiles as u32); - let reshaped_pullback = - self.reshape(next_pullback, new_shape)?; + let reshaped_pullback = self.reshape(next_pullback, new_shape)?; dependent_pullbacks.push(self.reduce_sum( reshaped_pullback, dim, diff --git a/src/core/graph/context.rs b/src/core/graph/context.rs index a3211ff..8f2e9cb 100644 --- a/src/core/graph/context.rs +++ b/src/core/graph/context.rs @@ -27,6 +27,9 @@ pub enum ContextError { #[error("Mismatched types {0} {1} at {2}")] IncompatibleOperandTypes(xla::ElementType, xla::ElementType, Callsite), + #[error("Expected shape {0} to have size greater than shape {1} at {2}")] + 
ExpectedGreaterSize(Shape, Shape, Callsite), + #[error("Tried to call reshape_const on non-constant node at {0}")] NonConstantReshape(Callsite), diff --git a/src/core/graph/math.rs b/src/core/graph/math.rs index cbc3b5e..f51c764 100644 --- a/src/core/graph/math.rs +++ b/src/core/graph/math.rs @@ -678,6 +678,7 @@ impl Context { // Utility function for tiling a small tensor to a larger shape // when it is known that the smaller tensor's shape broadcasts to the larger shape + // This utility is very handy for dealing with tiled constants in fold_consts pub fn tile_to_shape(&mut self, a: NodeIdentifier, shape: Shape) -> Result { let node_a_shape = self.nodes[a].shape.clone(); @@ -693,9 +694,9 @@ impl Context { )), Some(s) => { if s.size() > shape.size() { - return Err(ContextError::IncompatibleOperandShapes( - node_a_shape, + return Err(ContextError::ExpectedGreaterSize( shape.clone(), + s.clone(), callsite!(1), )); } @@ -788,6 +789,49 @@ impl Context { self.maybe_keepdims(node_id, dim, keepdims) } + // Utility function for summing a large tensor to a smaller shape + // when it is known that the smaller shape broadcasts to the larger tensor's shape + // This utility is very handy for dealing with broadcasted operands in autodiff + pub fn sum_to_shape(&mut self, a: NodeIdentifier, shape: Shape) -> Result { + let node_a_shape = self.nodes[a].shape.clone(); + + if node_a_shape == shape { + return Ok(a); + } + + match node_a_shape.broadcast(&shape) { + None => Err(ContextError::IncompatibleOperandShapes( + node_a_shape, + shape.clone(), + callsite!(1), + )), + Some(s) => { + if shape.size() > s.size() { + return Err(ContextError::ExpectedGreaterSize( + s.clone(), + shape.clone(), + callsite!(1), + )); + } + if shape.sizes.is_empty() { + let mut summed = a; + for _d in (0..s.ndims()).rev() { + summed = self.reduce_sum(summed, 0, false)?; + } + Ok(summed) + } else { + let mut summed = a; + for d in 0..s.ndims() { + if shape.sizes[d] == 1 { + summed = self.reduce_sum(summed, d as i64, true)?; + } + } + Ok(summed) + } + } + } + } + pub fn reduce_mean( &mut self, a: NodeIdentifier, From cfe2335aea29d7e87330230bd2dd6db8947a5b49 Mon Sep 17 00:00:00 2001 From: Eben Kadile Date: Wed, 27 Mar 2024 13:08:08 +0100 Subject: [PATCH 16/21] fixed minor build errors --- src/core/graph/operation.rs | 267 ++++++++++++++++++++++-------------- src/core/graph/subterm.rs | 90 +++++++----- 2 files changed, 220 insertions(+), 137 deletions(-) diff --git a/src/core/graph/operation.rs b/src/core/graph/operation.rs index 4d9d85f..b6394ef 100644 --- a/src/core/graph/operation.rs +++ b/src/core/graph/operation.rs @@ -1,5 +1,8 @@ use super::*; -use std::{fmt::{Display, Formatter, Result}, hash::Hash}; +use std::{ + fmt::{Display, Formatter, Result}, + hash::Hash, +}; use strum_macros::EnumDiscriminants; #[derive(Debug, Clone, EnumDiscriminants)] @@ -23,23 +26,48 @@ pub enum Operation { LessThanEq(NodeIdentifier, NodeIdentifier), GreaterThanEq(NodeIdentifier, NodeIdentifier), - Select{ pred: NodeIdentifier, on_true: NodeIdentifier, on_false: NodeIdentifier }, + Select { + pred: NodeIdentifier, + on_true: NodeIdentifier, + on_false: NodeIdentifier, + }, TypeCast(NodeIdentifier, xla::ElementType), Reshape(NodeIdentifier), Transpose(NodeIdentifier, Vec), MatMul(NodeIdentifier, NodeIdentifier), - SliceInDim{ node: NodeIdentifier, start: i64, stop: i64, stride: i64, dim: i64 }, - TileInDim{ node: NodeIdentifier, n_tiles: i64, dim: i64 }, + SliceInDim { + node: NodeIdentifier, + start: i64, + stop: i64, + stride: i64, + dim: i64, + }, 
+ TileInDim { + node: NodeIdentifier, + n_tiles: i64, + dim: i64, + }, ZerosLike(NodeIdentifier), - ReduceMax{ node: NodeIdentifier, dim: i64, }, - ReduceSum{ node: NodeIdentifier, dim: i64, }, - // TODO: This might not behave well for integral types! Figure out behavior. - ReduceMean{ node: NodeIdentifier, dim: i64, }, - ReduceArgmax{ node: NodeIdentifier, dim: i64, }, + ReduceMax { + node: NodeIdentifier, + dim: i64, + }, + ReduceSum { + node: NodeIdentifier, + dim: i64, + }, + ReduceMean { + node: NodeIdentifier, + dim: i64, + }, + ReduceArgmax { + node: NodeIdentifier, + dim: i64, + }, OneHot(NodeIdentifier), } @@ -47,65 +75,73 @@ pub enum Operation { impl Hash for Operation { fn hash(&self, state: &mut H) { match self { - Self::Add(a, b) - | Self::Mul(a, b) - | Self::Sub(a, b) - | Self::Div(a, b) - | Self::NotEqual(a, b) - | Self::Equal(a, b) - | Self::LessThan(a, b) - | Self::LessThanEq(a, b) - | Self::GreaterThanEq(a, b) - | Self::GreaterThan(a, b) - | Self::Pow(a, b) - | Self::MatMul(a, b) => { - a.hash(state); - b.hash(state); - }, + Self::Add(a, b) + | Self::Mul(a, b) + | Self::Sub(a, b) + | Self::Div(a, b) + | Self::NotEqual(a, b) + | Self::Equal(a, b) + | Self::LessThan(a, b) + | Self::LessThanEq(a, b) + | Self::GreaterThanEq(a, b) + | Self::GreaterThan(a, b) + | Self::Pow(a, b) + | Self::MatMul(a, b) => { + a.hash(state); + b.hash(state); + } Self::TypeCast(a, ty) => { a.hash(state); (ty.primitive_type() as usize).hash(state); - }, + } Self::Constant(a) => { //This is a little silly but it should work tbh //Might want to redo this later a.to_string().hash(state); - }, + } Self::Parameter(a) => { a.hash(state); - }, + } Self::StopGradient(a) => { a.hash(state); - }, - Self::Log(a) - | Self::Exp(a) - | Self::Reshape(a) - | Self::ZerosLike(a) - | Self::Neg(a) => { - a.hash(state); - } - Self::Select { pred, on_true, on_false } => { + } + Self::Log(a) | Self::Exp(a) | Self::Reshape(a) | Self::ZerosLike(a) | Self::Neg(a) => { + a.hash(state); + } + Self::Select { + pred, + on_true, + on_false, + } => { pred.hash(state); on_true.hash(state); on_false.hash(state); - }, - Self::ReduceMax{ node, dim, } - | Self::ReduceMean { node, dim } - | Self::ReduceSum { node, dim } => { - node.hash(state); - dim.hash(state) - }, + } + Self::ReduceMax { node, dim } + | Self::ReduceMean { node, dim } + | Self::ReduceSum { node, dim } + | Self::ReduceArgmax { node, dim } => { + node.hash(state); + dim.hash(state) + } + Self::OneHot(node) => node.hash(state), Self::Transpose(a, dim) => { a.hash(state); dim.hash(state); - }, - Self::SliceInDim { node, start, stop, stride, dim } => { + } + Self::SliceInDim { + node, + start, + stop, + stride, + dim, + } => { node.hash(state); start.hash(state); stop.hash(state); stride.hash(state); dim.hash(state); - }, + } Self::TileInDim { node, n_tiles, dim } => { node.hash(state); n_tiles.hash(state); @@ -120,65 +156,94 @@ impl PartialEq for Operation { return match (&self, &other) { //Order not matering. 
Ex: 1 + 2 equals 2 + 1, but 1 / 2 doesnt equal 2 /1 so we can //check these separately - (&Self::Mul(a, b), &Self::Mul(c, d)) - | (&Self::Equal(a, b), &Self::Equal(c, d)) - | (&Self::NotEqual(a, b), &Self::NotEqual(c, d)) - | (&Self::Add(a, b), &Self::Add(c, d)) => { - (a == c && b == d) || (a == d && b == c) - } + (&Self::Mul(a, b), &Self::Mul(c, d)) + | (&Self::Equal(a, b), &Self::Equal(c, d)) + | (&Self::NotEqual(a, b), &Self::NotEqual(c, d)) + | (&Self::Add(a, b), &Self::Add(c, d)) => (a == c && b == d) || (a == d && b == c), //Order does matter, so div, sub, pow etc - (&Self::Div(a, b), &Self::Div(c, d)) - | (&Self::Pow(a, b), &Self::Pow(c, d)) - | (&Self::LessThan(a, b), &Self::LessThan(c, d)) - | (&Self::GreaterThan(a, b), &Self::GreaterThan(c, d)) - | (&Self::GreaterThanEq(a, b), &Self::GreaterThanEq(c, d)) - | (&Self::LessThanEq(a, b), &Self::LessThanEq(c, d)) - | (&Self::MatMul(a, b), &Self::MatMul(c, d)) - | (&Self::Sub(a, b), &Self::Sub(c, d)) => { - a == c && b == d - } - (&Self::StopGradient(a), &Self::StopGradient(b)) - | (&Self::Neg(a), &Self::Neg(b)) - | (&Self::ZerosLike(a), &Self::ZerosLike(b)) - | (&Self::Exp(a), &Self::Exp(b)) - | (&Self::Reshape(a), &Self::Reshape(b)) - | (&Self::Log(a), &Self::Log(b)) => { - a == b - } + (&Self::Div(a, b), &Self::Div(c, d)) + | (&Self::Pow(a, b), &Self::Pow(c, d)) + | (&Self::LessThan(a, b), &Self::LessThan(c, d)) + | (&Self::GreaterThan(a, b), &Self::GreaterThan(c, d)) + | (&Self::GreaterThanEq(a, b), &Self::GreaterThanEq(c, d)) + | (&Self::LessThanEq(a, b), &Self::LessThanEq(c, d)) + | (&Self::MatMul(a, b), &Self::MatMul(c, d)) + | (&Self::Sub(a, b), &Self::Sub(c, d)) => a == c && b == d, + (&Self::StopGradient(a), &Self::StopGradient(b)) + | (&Self::Neg(a), &Self::Neg(b)) + | (&Self::ZerosLike(a), &Self::ZerosLike(b)) + | (&Self::Exp(a), &Self::Exp(b)) + | (&Self::Reshape(a), &Self::Reshape(b)) + | (&Self::Log(a), &Self::Log(b)) => a == b, (&Self::Constant(a), &Self::Constant(b)) => a.to_string() == b.to_string(), (&Self::Parameter(a), &Self::Parameter(b)) => a == b, - (&Self::Select { pred, on_true, on_false }, - &Self::Select { pred: pred2, on_true: on_true2, on_false: on_false2 }) => { - pred == pred2 && on_true == on_true2 && on_false == on_false2 - } - (&Self::TypeCast(a, ty), &Self::TypeCast(b, ty2)) => { - a == b && ty == ty2 - } - (&Self::Transpose(a, dim), &Self::Transpose(b, dim2)) => { - a == b && dim == dim2 - } - (&Self::SliceInDim { node, start, stop, stride, dim }, - &Self::SliceInDim { node: node2, start: start2, stop: stop2, stride: stride2, dim: dim2 }) => { - node == node2 && - start == start2 && - stop == stop2 && - stride == stride2 && - dim == dim2 - } - (&Self::TileInDim { node, n_tiles, dim }, - &Self::TileInDim { node: node2, n_tiles: n_tiles2, dim: dim2 }) => { - node == node2 && - n_tiles == n_tiles2 && - dim == dim2 + ( + &Self::Select { + pred, + on_true, + on_false, + }, + &Self::Select { + pred: pred2, + on_true: on_true2, + on_false: on_false2, + }, + ) => pred == pred2 && on_true == on_true2 && on_false == on_false2, + (&Self::TypeCast(a, ty), &Self::TypeCast(b, ty2)) => a == b && ty == ty2, + (&Self::Transpose(a, dim), &Self::Transpose(b, dim2)) => a == b && dim == dim2, + ( + &Self::SliceInDim { + node, + start, + stop, + stride, + dim, + }, + &Self::SliceInDim { + node: node2, + start: start2, + stop: stop2, + stride: stride2, + dim: dim2, + }, + ) => { + node == node2 + && start == start2 + && stop == stop2 + && stride == stride2 + && dim == dim2 } - (&Self::ReduceMax { node, dim }, 
&Self::ReduceMax { node: node2, dim: dim2 }) - | (&Self::ReduceMean { node, dim }, &Self::ReduceMean { node: node2, dim: dim2 }) - | (&Self::ReduceSum { node, dim }, &Self::ReduceSum { node: node2, dim: dim2 }) => { - node == node2 && - dim == dim2 - } - _ => false - } + ( + &Self::TileInDim { node, n_tiles, dim }, + &Self::TileInDim { + node: node2, + n_tiles: n_tiles2, + dim: dim2, + }, + ) => node == node2 && n_tiles == n_tiles2 && dim == dim2, + ( + &Self::ReduceMax { node, dim }, + &Self::ReduceMax { + node: node2, + dim: dim2, + }, + ) + | ( + &Self::ReduceMean { node, dim }, + &Self::ReduceMean { + node: node2, + dim: dim2, + }, + ) + | ( + &Self::ReduceSum { node, dim }, + &Self::ReduceSum { + node: node2, + dim: dim2, + }, + ) => node == node2 && dim == dim2, + _ => false, + }; } } diff --git a/src/core/graph/subterm.rs b/src/core/graph/subterm.rs index a33a293..92d5e89 100644 --- a/src/core/graph/subterm.rs +++ b/src/core/graph/subterm.rs @@ -16,7 +16,7 @@ impl Context { return Ok(true); } let mut node_map: HashMap = HashMap::new(); - + let mut modifications = 0; let mut changed = false; @@ -27,7 +27,7 @@ impl Context { if visited.contains(&node_id) || modifications >= modification_limit { continue; } - + if node_map.contains_key(&self.nodes[node_id]) { self.replace_index(node_id, node_map[&self.nodes[node_id]])?; modifications += 1; @@ -35,43 +35,59 @@ impl Context { } else { node_map.insert(self.nodes[node_id].clone(), node_id); } - + visited.insert(node_id); //Add operation nodes to the queue match self.nodes[node_id].operation { - Operation::Add(a, b) - | Operation::Sub(a, b) - | Operation::Mul(a, b) - | Operation::Div(a, b) - | Operation::NotEqual(a, b) - | Operation::Equal(a, b) - | Operation::LessThan(a, b) - | Operation::GreaterThan(a, b) - | Operation::GreaterThanEq(a, b) - | Operation::LessThanEq(a, b) - | Operation::MatMul(a, b) - | Operation::Pow(a, b) => { - to_visit.push(a); - to_visit.push(b); - } - Operation::Neg(a) - | Operation::StopGradient(a) - | Operation::Log(a) - | Operation::Exp(a) - | Operation::TypeCast(a, _) - | Operation::Transpose(a, _) - | Operation::SliceInDim { node: a, start: _, stop: _, stride: _, dim: _ } - | Operation::TileInDim { node: a, n_tiles: _, dim: _ } - | Operation::Reshape(a) - | Operation::ZerosLike(a) => { - to_visit.push(a); - } + Operation::Add(a, b) + | Operation::Sub(a, b) + | Operation::Mul(a, b) + | Operation::Div(a, b) + | Operation::NotEqual(a, b) + | Operation::Equal(a, b) + | Operation::LessThan(a, b) + | Operation::GreaterThan(a, b) + | Operation::GreaterThanEq(a, b) + | Operation::LessThanEq(a, b) + | Operation::MatMul(a, b) + | Operation::Pow(a, b) => { + to_visit.push(a); + to_visit.push(b); + } + Operation::Neg(a) + | Operation::StopGradient(a) + | Operation::Log(a) + | Operation::Exp(a) + | Operation::TypeCast(a, _) + | Operation::Transpose(a, _) + | Operation::SliceInDim { + node: a, + start: _, + stop: _, + stride: _, + dim: _, + } + | Operation::TileInDim { + node: a, + n_tiles: _, + dim: _, + } + | Operation::Reshape(a) + | Operation::ZerosLike(a) => { + to_visit.push(a); + } Operation::ReduceMax { node, dim: _ } - | Operation::ReduceMean { node, dim: _ } - | Operation::ReduceSum { node, dim: _ } => { - to_visit.push(node); - } - Operation::Select { pred, on_true, on_false } => { + | Operation::ReduceMean { node, dim: _ } + | Operation::ReduceSum { node, dim: _ } + | Operation::ReduceArgmax { node, dim: _ } => { + to_visit.push(node); + } + Operation::OneHot(node) => to_visit.push(node), + Operation::Select { + 
pred, + on_true, + on_false, + } => { to_visit.push(pred); to_visit.push(on_true); to_visit.push(on_false); @@ -83,7 +99,9 @@ impl Context { //Recursive recall if we changed something and modifications are still available match changed { false => Ok(false), - true => Ok(changed || self.extract_subterms(outputs, modification_limit - modifications)?) + true => Ok( + changed || self.extract_subterms(outputs, modification_limit - modifications)? + ), } } } From 994446fafd6a4339c41bc6a549900279e1b48565 Mon Sep 17 00:00:00 2001 From: Eben Kadile Date: Wed, 27 Mar 2024 22:10:18 +0100 Subject: [PATCH 17/21] got rid of clone in training loop --- examples/mnist_xla.rs | 45 ++++++++++++++++--------------------------- 1 file changed, 17 insertions(+), 28 deletions(-) diff --git a/examples/mnist_xla.rs b/examples/mnist_xla.rs index 2ac8d66..112a79d 100644 --- a/examples/mnist_xla.rs +++ b/examples/mnist_xla.rs @@ -1,12 +1,12 @@ use std::fs::File; use std::io; -use std::time::Instant; use std::os::unix::fs::*; +use std::time::Instant; use unda::core::graph::*; use xla::{ElementType::*, PjRtClient, PjRtLoadedExecutable}; const USE_CPU: bool = false; -const MNIST_DIRECTORY: &str = "/home/ekadile/mnist"; +const MNIST_DIRECTORY: &str = "/home/medusa/mnist"; const EPOCHS: usize = 20; const INIT_LEARNING_RATE: f32 = 1e-3; const LEARNING_RATE_DECAY: f32 = 0.9; @@ -172,6 +172,9 @@ fn init_param(shape: Shape) -> xla::Literal { // are declared in the model context. This becomes insanely hard // to keep track of as the architecture grows, and the user shouldn't // have to worry about it. +// I think the simplest way to achieve this (which is akin to how JAX does it) +// would be to have `Model` objects which are called with two Into> +// structures, one for inputs and one for parameters. fn init_params() -> ( xla::Literal, xla::Literal, @@ -288,7 +291,7 @@ fn main() { let xla_literal = xla_buffer[0][0] .to_literal_sync() .expect("Failed to copy buffer to host"); - let untupled_literals = xla_literal + let mut untupled_literals = xla_literal .to_tuple() .expect("Failed to untuple XLA literals"); @@ -303,17 +306,14 @@ fn main() { // This is really very silly. Because model/optimizer are not separate // we move the weights to the CPU just to move them back - // Even without that, is there a way to get rid of the clone?? - (w1, b1, w2, b2, w3, b3, w_out, b_out) = ( - untupled_literals[2].clone(), - untupled_literals[3].clone(), - untupled_literals[4].clone(), - untupled_literals[5].clone(), - untupled_literals[6].clone(), - untupled_literals[7].clone(), - untupled_literals[8].clone(), - untupled_literals[9].clone() - ); + b_out = untupled_literals.pop().unwrap(); + w_out = untupled_literals.pop().unwrap(); + b3 = untupled_literals.pop().unwrap(); + w3 = untupled_literals.pop().unwrap(); + b2 = untupled_literals.pop().unwrap(); + w2 = untupled_literals.pop().unwrap(); + b1 = untupled_literals.pop().unwrap(); + w1 = untupled_literals.pop().unwrap(); } println!( "Epoch {}: Training loss = {}; Training accuracy = {}", @@ -333,9 +333,8 @@ fn main() { let mut test_loss = 0f32; for batch_idx in 0..100 { - let (test_imgs, test_lbls) = - load_mnist_batch(&test_images, &test_labels, batch_idx) - .expect("Failed to load MNIST batch"); + let (test_imgs, test_lbls) = load_mnist_batch(&test_images, &test_labels, batch_idx) + .expect("Failed to load MNIST batch"); // GOOFY!! 
// Another consequence of ABSTRACT API REQUIREMENT 4 Not being implemented @@ -345,17 +344,7 @@ fn main() { let xla_buffer = executable .execute(&[ - &test_imgs, - &test_lbls, - &w1, - &b1, - &w2, - &b2, - &w3, - &b3, - &w_out, - &b_out, - &lr, + &test_imgs, &test_lbls, &w1, &b1, &w2, &b2, &w3, &b3, &w_out, &b_out, &lr, ]) .expect("Failed to run PjRt executable"); From 33636fe7a2a9cb56a3f38011cd2584e73391170f Mon Sep 17 00:00:00 2001 From: Eben Kadile Date: Wed, 27 Mar 2024 22:34:06 +0100 Subject: [PATCH 18/21] minor cleanups --- examples/mnist_xla.rs | 2 +- src/core/graph/autodiff.rs | 54 ++++---------------------------------- src/core/graph/tests.rs | 4 +-- 3 files changed, 8 insertions(+), 52 deletions(-) diff --git a/examples/mnist_xla.rs b/examples/mnist_xla.rs index 112a79d..c67e610 100644 --- a/examples/mnist_xla.rs +++ b/examples/mnist_xla.rs @@ -65,7 +65,7 @@ fn build_model_and_optimizer(client: &xla::PjRtClient) -> Result Result { let wrt_shape = self.nodes[with_respect_to].shape.clone(); - let wrt_dtype = self.nodes[with_respect_to].dtype; - - if ![ - xla::ElementType::F16, - xla::ElementType::Bf16, - xla::ElementType::F32, - xla::ElementType::F64, - xla::ElementType::C64, - xla::ElementType::C128, - ] - .contains(&wrt_dtype) - { - return Err(ContextError::NonDifferentiableTypeError( - self.nodes[with_respect_to].callsite.clone(), - )); - } + let wrt_dtype = check_fp_type(self.nodes[with_respect_to].dtype)?; + if output == with_respect_to { return self.scalar(1, wrt_dtype); } @@ -104,39 +92,6 @@ impl Context { let node_sh = self.nodes[node].shape.clone(); let pullback = self.reshape(next_pullback, node_sh)?; dependent_pullbacks.push(pullback); - /* - if self.nodes[next_pullback].shape.size() == node_sh.size() { - let pullback = self.reshape(next_pullback, node_sh)?; - dependent_pullbacks.push(pullback); - } else if let Some(_s) = self.nodes[next_pullback] - .shape - .broadcast(&self.nodes[dependent_node].shape) - { - // this is the case where the gradients picked up some - // broadcasted batch dimensions along the way - let mut sum_pullback = next_pullback; - if self.nodes[dependent_node].shape.sizes.is_empty() { - for _ in 0..self.nodes[next_pullback].shape.ndims() { - sum_pullback = self.reduce_sum(next_pullback, 0, false)?; - } - dependent_pullbacks.push(sum_pullback); - } else { - for d in 0..self.nodes[next_pullback].shape.ndims() { - if self.nodes[dependent_node].shape.sizes[d] == 1 { - sum_pullback = - self.reduce_sum(next_pullback, d as i64, true)? 
- } - } - dependent_pullbacks.push(sum_pullback); - } - } else { - return Err(ContextError::IncompatibleOperandShapes( - self.nodes[next_pullback].shape.clone(), - self.nodes[dependent_node].shape.clone(), - callsite!(1), - )); - } - */ } Operation::Transpose(a, p) => { @@ -204,6 +159,7 @@ impl Context { Operation::MatMul(a, b) => { let next_pullback = self.diff(output, dependent_node)?; + // TODO: Add case for a == b if a == with_respect_to { let mut transpose_dims = Vec::from_iter(0..(self.nodes[b].shape.ndims() as i64)); diff --git a/src/core/graph/tests.rs b/src/core/graph/tests.rs index bfb96d4..a22e6ca 100644 --- a/src/core/graph/tests.rs +++ b/src/core/graph/tests.rs @@ -319,7 +319,7 @@ mod tests { assert_eq!(rust_result[0], f32::exp(1f32)); } - + #[test] fn test_pow() { let mut ctx = Context::new(); @@ -601,7 +601,7 @@ mod tests { let quartic_term = ctx.mul(quarter, x4).expect("quartic_term"); let y = ctx.sub(quartic_term, quadratic_term).expect("y"); - let dydx = ctx.diff(y, x.into()).expect("dydx"); + let dydx = ctx.diff(y, x).expect("dydx"); ctx.fold_consts(dydx, usize::max_value()) .expect("fold consts"); println!("{}", ctx.to_string(dydx)); From 2f50db83f723f2a154cbf50713c0d68da98063a4 Mon Sep 17 00:00:00 2001 From: Eben Kadile Date: Thu, 4 Apr 2024 14:39:09 +0200 Subject: [PATCH 19/21] no more memory leak! --- Cargo.lock | 48 +++++++------- Cargo.toml | 2 +- examples/mnist_xla.rs | 122 +++++++++++++++++++----------------- src/core/graph/math.rs | 4 +- src/core/graph/tests_cpu.rs | 2 +- 5 files changed, 92 insertions(+), 86 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 81fe167..2a2b344 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -39,9 +39,9 @@ dependencies = [ [[package]] name = "autocfg" -version = "1.1.0" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" [[package]] name = "backtrace" @@ -358,7 +358,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.53", + "syn 2.0.58", ] [[package]] @@ -490,9 +490,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.10" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" [[package]] name = "jobserver" @@ -551,9 +551,9 @@ checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" [[package]] name = "memchr" -version = "2.7.1" +version = "2.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" +checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" [[package]] name = "minimal-lexical" @@ -653,9 +653,9 @@ checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" [[package]] name = "pin-project-lite" -version = "0.2.13" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" +checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" [[package]] name = "pin-utils" @@ -759,9 +759,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.3" +version = "1.10.4" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15" +checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" dependencies = [ "aho-corasick", "memchr", @@ -782,9 +782,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" +checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" [[package]] name = "rustc-demangle" @@ -850,14 +850,14 @@ checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.53", + "syn 2.0.58", ] [[package]] name = "serde_json" -version = "1.0.114" +version = "1.0.115" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5f09b1bd632ef549eaa9f60a1f8de742bdbc698e6cee2095fc84dde5f549ae0" +checksum = "12dc5c46daa8e9fdf4f5e71b6cf9a53f2487da0e86e55808e2d35539666497dd" dependencies = [ "itoa", "ryu", @@ -932,7 +932,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.53", + "syn 2.0.58", ] [[package]] @@ -954,9 +954,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.53" +version = "2.0.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7383cd0e49fff4b6b90ca5670bfd3e9d6a733b3f90c686605aa7eec8c4996032" +checksum = "44cfb93f38070beee36b3fef7d4f5a16f27751d94b187b666a5cc5e9b0d30687" dependencies = [ "proc-macro2", "quote", @@ -980,7 +980,7 @@ checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.53", + "syn 2.0.58", ] [[package]] @@ -1029,7 +1029,7 @@ checksum = "ac73887f47b9312552aa90ef477927ff014d63d1920ca8037c6c1951eab64bb1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.53", + "syn 2.0.58", ] [[package]] @@ -1157,7 +1157,7 @@ checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" [[package]] name = "xla" version = "0.1.6" -source = "git+https://github.com/Ebanflo42/xla-rs?branch=dev#61d7d8af8891eabed91b91990a9d09bddd8e86fd" +source = "git+https://github.com/Ebanflo42/xla-rs?branch=main#30a75ee643b560021805ff90f32c551f9ca26220" dependencies = [ "bindgen", "cc", @@ -1210,9 +1210,9 @@ dependencies = [ [[package]] name = "zstd-sys" -version = "2.0.9+zstd.1.5.5" +version = "2.0.10+zstd.1.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e16efa8a874a0481a574084d34cc26fdb3b99627480f785888deb6386506656" +checksum = "c253a4914af5bafc8fa8c86ee400827e83cf6ec01195ec1f1ed8441bf00d65aa" dependencies = [ "cc", "pkg-config", diff --git a/Cargo.toml b/Cargo.toml index e7f95dc..25174e9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,7 @@ backtrace = "0.3" smallvec = "1.13" strum = "0.26" strum_macros = "0.26" -xla = { git = "https://github.com/Ebanflo42/xla-rs", version = "0.1.6" , branch = "dev" } +xla = { git = "https://github.com/Ebanflo42/xla-rs", version = "0.1.6" , branch = "main" } thiserror = "1" half = "2.4.0" byteorder = "1.5" diff --git a/examples/mnist_xla.rs b/examples/mnist_xla.rs index c67e610..d8fca5b 100644 --- a/examples/mnist_xla.rs +++ b/examples/mnist_xla.rs @@ -6,10 +6,12 @@ use unda::core::graph::*; use xla::{ElementType::*, PjRtClient, PjRtLoadedExecutable}; const USE_CPU: bool = false; -const MNIST_DIRECTORY: &str = "/home/medusa/mnist"; -const EPOCHS: usize = 20; +const 
MEM_FRAC: f64 = 0.9; +const MNIST_DIRECTORY: &str = "/home/ekadile/mnist"; +const EPOCHS: usize = 100; const INIT_LEARNING_RATE: f32 = 1e-3; -const LEARNING_RATE_DECAY: f32 = 0.9; +const LEARNING_RATE_DECAY: f32 = 0.95; +const MIN_LEARNING_RATE: f32 = 1e-5; // ABSTRACT API REQUIREMENT 1: Automatic Layer Construction // We should have functions like this which, for a given layer type, @@ -70,13 +72,13 @@ fn build_model_and_optimizer(client: &xla::PjRtClient) -> Result Result Result Result xla::Literal { // would be to have `Model` objects which are called with two Into> // structures, one for inputs and one for parameters. fn init_params() -> ( - xla::Literal, - xla::Literal, - xla::Literal, - xla::Literal, - xla::Literal, - xla::Literal, + //xla::Literal, + //xla::Literal, + //xla::Literal, + //xla::Literal, + //xla::Literal, + //xla::Literal, xla::Literal, xla::Literal, ) { ( - init_param(Shape::from([28 * 28, 784])), - init_param(Shape::from([1, 784])), - init_param(Shape::from([784, 256])), - init_param(Shape::from([1, 256])), - init_param(Shape::from([256, 64])), - init_param(Shape::from([1, 64])), - init_param(Shape::from([64, 10])), + //init_param(Shape::from([28 * 28, 2000])), + //init_param(Shape::from([1, 2000])), + //init_param(Shape::from([784, 256])), + //init_param(Shape::from([1, 256])), + //init_param(Shape::from([256, 64])), + //init_param(Shape::from([1, 64])), + init_param(Shape::from([784, 10])), init_param(Shape::from([1, 10])), ) } @@ -227,7 +233,7 @@ fn main() { let client = if USE_CPU { PjRtClient::cpu().expect("Failed to construct CPU client") } else { - PjRtClient::gpu(0.9, false).expect("Failed to construct GPU client") + PjRtClient::gpu(MEM_FRAC, false).expect("Failed to construct GPU client") }; let mut train_img_path = MNIST_DIRECTORY.to_owned(); @@ -252,7 +258,7 @@ fn main() { build_model_and_optimizer(&client).expect("Failed to build model and optimizer"); println!("Finished build in {:.2?}", now.elapsed()); - let (mut w1, mut b1, mut w2, mut b2, mut w3, mut b3, mut w_out, mut b_out) = init_params(); + let (mut w_out, mut b_out) = init_params(); println!("Training model . . 
."); let now = Instant::now(); @@ -266,7 +272,7 @@ fn main() { .expect("Failed to load MNIST batch"); let lr = - xla::Literal::scalar(INIT_LEARNING_RATE * (LEARNING_RATE_DECAY.powf(epoch as f32))); + xla::Literal::scalar(MIN_LEARNING_RATE.max(INIT_LEARNING_RATE * (LEARNING_RATE_DECAY.powf(epoch as f32)))); // This is where ABSTRACT API REQUIREMENT 5 becomes pertinent // The user should not have to explicitly reference a dozen parameters like this @@ -274,12 +280,12 @@ fn main() { .execute(&[ &train_imgs, &train_lbls, - &w1, - &b1, - &w2, - &b2, - &w3, - &b3, + //&w1, + //&b1, + //&w2, + //&b2, + //&w3, + //&b3, &w_out, &b_out, &lr, @@ -308,12 +314,12 @@ fn main() { // we move the weights to the CPU just to move them back b_out = untupled_literals.pop().unwrap(); w_out = untupled_literals.pop().unwrap(); - b3 = untupled_literals.pop().unwrap(); - w3 = untupled_literals.pop().unwrap(); - b2 = untupled_literals.pop().unwrap(); - w2 = untupled_literals.pop().unwrap(); - b1 = untupled_literals.pop().unwrap(); - w1 = untupled_literals.pop().unwrap(); + //b3 = untupled_literals.pop().unwrap(); + //w3 = untupled_literals.pop().unwrap(); + //b2 = untupled_literals.pop().unwrap(); + //w2 = untupled_literals.pop().unwrap(); + //b1 = untupled_literals.pop().unwrap(); + //w1 = untupled_literals.pop().unwrap(); } println!( "Epoch {}: Training loss = {}; Training accuracy = {}", @@ -344,7 +350,7 @@ fn main() { let xla_buffer = executable .execute(&[ - &test_imgs, &test_lbls, &w1, &b1, &w2, &b2, &w3, &b3, &w_out, &b_out, &lr, + &test_imgs, &test_lbls, &w_out, &b_out, &lr, ]) .expect("Failed to run PjRt executable"); diff --git a/src/core/graph/math.rs b/src/core/graph/math.rs index f51c764..ff35b55 100644 --- a/src/core/graph/math.rs +++ b/src/core/graph/math.rs @@ -498,10 +498,10 @@ impl Context { self.maximum(const_zero, a) } - pub fn leaky_relu(&mut self, a: NodeIdentifier) -> Result { + pub fn leaky_relu(&mut self, a: NodeIdentifier, alpha: f32) -> Result { let a_dtype = self.nodes[a].dtype; //TODO: force dtype to be floating point or else this just becomes normal relu - let const_small = self.scalar(0.001, a_dtype)?; + let const_small = self.scalar(alpha, a_dtype)?; let small_x = self.mul(a, const_small)?; self.maximum(small_x, a) diff --git a/src/core/graph/tests_cpu.rs b/src/core/graph/tests_cpu.rs index d4828da..e665671 100644 --- a/src/core/graph/tests_cpu.rs +++ b/src/core/graph/tests_cpu.rs @@ -148,7 +148,7 @@ mod tests { let mut ctx = Context::new(); let x = ctx.parameter("x", [], xla::ElementType::F32).expect("x"); - let lr = ctx.leaky_relu(x).expect("leaky_relu x"); + let lr = ctx.leaky_relu(x, 0.001).expect("leaky_relu x"); let client = xla::PjRtClient::cpu().expect("client"); let name = "test"; From a3e2c7c282eec61e3def9629ea3f84385e2b03a1 Mon Sep 17 00:00:00 2001 From: Braden Everson Date: Thu, 4 Apr 2024 07:52:00 -0500 Subject: [PATCH 20/21] Remove semicolon on eq fn so it returns a bool --- src/core/graph/operation.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/core/graph/operation.rs b/src/core/graph/operation.rs index fb4d24e..14c9f8d 100644 --- a/src/core/graph/operation.rs +++ b/src/core/graph/operation.rs @@ -243,7 +243,7 @@ impl PartialEq for Operation { }, ) => node == node2 && dim == dim2, _ => false, - }; + } } } From d77efc0c5c42bd1e5fd3d48e5a9f4dd24028046e Mon Sep 17 00:00:00 2001 From: Braden Everson Date: Thu, 4 Apr 2024 07:58:17 -0500 Subject: [PATCH 21/21] Accidentally removed OneHot and ReduceArgmax from subterm --- src/core/graph/subterm.rs | 2 
++ 1 file changed, 2 insertions(+) diff --git a/src/core/graph/subterm.rs b/src/core/graph/subterm.rs index 0f4a035..0bf0a3d 100644 --- a/src/core/graph/subterm.rs +++ b/src/core/graph/subterm.rs @@ -65,6 +65,7 @@ impl Context { } Operation::ReduceMax { node, dim: _ } | Operation::ReduceMean { node, dim: _ } + | Operation::ReduceArgmax { node, dim: _ } | Operation::ReduceSum { node, dim: _ } => { to_visit.push(node); } @@ -73,6 +74,7 @@ impl Context { to_visit.push(on_true); to_visit.push(on_false); } + Operation::OneHot(node) => to_visit.push(node), Operation::Constant(_) | Operation::Parameter(_) => {} } node_map.insert(self.nodes[node_id].clone(), node_id);