More Model Module Docs (#2623)
* dinov2

* add another example

* add dinov2reg4

* eva2

* efficientvit

* moondream

* update t5

* update t5

* rwkv

* stable diffusion docs

* add wasm link

* add segment_anything

* adjust for clippy

* ignore bertdoc

* dinov2 ignore

* update block to be text

* remove the rust blocks for the moment

* bump python to 3.11

* add a setup-python step

* add py311 to test as well
zachcp authored Nov 17, 2024
1 parent a3f200e commit 12d7e7b
Showing 12 changed files with 291 additions and 72 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/rust-ci.yml
@@ -16,6 +16,9 @@ jobs:
        rust: [stable]
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - uses: actions-rs/toolchain@v1
        with:
          profile: minimal
@@ -35,6 +38,9 @@ jobs:
        rust: [stable]
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - uses: actions-rs/toolchain@v1
        with:
          profile: minimal
50 changes: 0 additions & 50 deletions candle-transformers/src/models/bert.rs
@@ -7,56 +7,6 @@
//! - Upstream [Github repo](https://github.com/google-research/bert).
//! - See bert in [candle-examples](https://github.com/huggingface/candle/tree/main/candle-examples/) for runnable code
//!
//! ```no_run
//! // for sentence embeddings
//! # use candle_core::Tensor;
//! # use candle_nn::{VarBuilder, Module};
//! # fn main() -> candle_core::Result<()> {
//! # let model = todo!();
//! # let prompt = "Here is a test sentence";
//! let embeddings = model.forward(prompt)?;
//! // Returns tensor of shape [1, 7, 384]
//! println!("{embeddings}");
//! # Ok(())
//! # }
//!
//! // Different models can be loaded using the model ID
//! # use candle_core::Tensor;
//! # use candle_nn::{VarBuilder, Module};
//! # fn main() -> candle_core::Result<()> {
//! # let vb = todo!();
//! # let config = todo!();
//! let model = BertModel::load(vb, &config )?;
//! # Ok(())
//! # }
//!
//! // Gelu approximation
//! // You can get a speedup by configuring the model
//! // to use an approximation of the gelu activation:
//! # use candle_core::Tensor;
//! # use candle_nn::{VarBuilder, Module};
//! # fn main() -> candle_core::Result<()> {
//! # let mut config = todo!();
//! config.hidden_act = HiddenAct::GeluApproximate;
//! # Ok(())
//! # }
//!
//! // Similarities
//! // Bert can compute sentence embeddings which can then be used to calculate
//! // semantic similarities between sentences through cosine similarity scoring.
//! // The sentence embeddings are computed using average pooling across all tokens.
//! # use candle_core::Tensor;
//! # use candle_nn::{VarBuilder, Module};
//! # fn main() -> candle_core::Result<()> {
//! # let model = todo!();
//! let sentence1 = "The new movie is awesome";
//! let sentence2 = "The new movie is so great";
//! let emb1 = model.forward(sentence1)?;
//! let emb2 = model.forward(sentence2)?;
//! # Ok(())
//! # }
//! ```
//!
use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
use candle::{DType, Device, Result, Tensor};
use candle_nn::{embedding, Embedding, Module, VarBuilder};
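The doc examples removed above described computing sentence embeddings by average pooling over the token states and comparing sentences with cosine similarity. Outside of doc-tests, those two post-processing steps can be sketched roughly as follows, operating on the `[n_sentences, n_tokens, hidden]` tensor that `BertModel::forward` returns in the candle bert example; the function names here are illustrative.

```rust
use candle_core::{Result, Tensor};

/// Mean-pool token states into one embedding per sentence:
/// [n_sentences, n_tokens, hidden] -> [n_sentences, hidden].
fn mean_pool(embeddings: &Tensor) -> Result<Tensor> {
    let (_n_sentences, n_tokens, _hidden) = embeddings.dims3()?;
    embeddings.sum(1)? / (n_tokens as f64)
}

/// Cosine similarity between two pooled embeddings of shape [hidden].
fn cosine_similarity(a: &Tensor, b: &Tensor) -> Result<f32> {
    let dot = (a * b)?.sum_all()?.to_scalar::<f32>()?;
    let norm_a = a.sqr()?.sum_all()?.to_scalar::<f32>()?.sqrt();
    let norm_b = b.sqr()?.sum_all()?.to_scalar::<f32>()?.sqrt();
    Ok(dot / (norm_a * norm_b))
}
```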
38 changes: 36 additions & 2 deletions candle-transformers/src/models/dinov2.rs
@@ -1,8 +1,42 @@
//! Implementation of the DINOv2 models from Meta Research.
//!
//! See:
//! - DINOv2: ["DINOv2: Learning Robust Visual Features without Supervision"](https://github.com/facebookresearch/dinov2)
//! This module implements the DINOv2 vision transformer model from Meta AI Research.
//! DINOv2 is a self-supervised learning model that can learn visual features
//! without using any labeled data. See: ["DINOv2: Learning Robust Visual Features without Supervision"](https://github.com/facebookresearch/dinov2)
//!
//! ## Running the Depth Anything v2 example (DINOv2 backbone) with color map and CUDA
//!
//! ```bash
//! cargo run \
//! --features cuda,depth_anything_v2 \
//! --package candle-examples \
//! --example depth_anything_v2 \
//! -- --color-map \
//! --image candle-examples/examples/yolo-v8/assets/bike.jpg
//! ```
//!
//! ## Running as an ImageNet classifier
//!
//! The model returns the probability for the image to belong to each of the 1000 ImageNet categories.
//!
//! <div align=center>
//! <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.jpg" alt="" width=640>
//! </div>
//!
//! ```bash
//! cargo run \
//! --example dinov2 \
//! --release \
//! -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
//!
//! > mountain bike, all-terrain bike, off-roader: 43.67%
//! > bicycle-built-for-two, tandem bicycle, tandem: 33.20%
//! > crash helmet : 13.23%
//! > unicycle, monocycle : 2.44%
//! > maillot : 2.42%
//! ```
//!
use candle::{IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};

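For orientation, the ImageNet-classifier path that the dinov2 example wraps can be sketched directly against this module. The `vit_small` constructor, the `lmz/candle-dino-v2` checkpoint name, and the 518×518 input resolution are taken from the candle example and should be treated as assumptions rather than a stable API.

```rust
use candle_core::{DType, Device, Result, Tensor, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::dinov2;

fn main() -> Result<()> {
    let device = Device::Cpu;

    // Fetch the ViT-S/14 weights from the Hugging Face Hub (repo and file
    // names follow the candle example and are assumptions here).
    let api = hf_hub::api::sync::Api::new().map_err(candle_core::Error::wrap)?;
    let weights = api
        .model("lmz/candle-dino-v2".to_string())
        .get("dinov2_vits14.safetensors")
        .map_err(candle_core::Error::wrap)?;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights], DType::F32, &device)? };
    let model = dinov2::vit_small(vb)?;

    // Stand-in input: a real run would substitute an ImageNet-normalized
    // image at the resolution the checkpoint expects (assumed 518x518 here).
    let image = Tensor::zeros((1, 3, 518, 518), DType::F32, &device)?;

    // Forward pass, then softmax over the 1000 ImageNet classes.
    let logits = model.forward(&image)?;
    let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
    println!("{probs}");
    Ok(())
}
```

The efficientvit and eva2 modules further down follow the same shape: construct the model from a `VarBuilder`, feed a normalized image batch, and softmax the resulting logits.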
31 changes: 28 additions & 3 deletions candle-transformers/src/models/dinov2reg4.rs
@@ -1,9 +1,34 @@
//! Implementation of the DINOv2 revision (4 regularization tokens)
//!
//! See:
//! - DINOv2: ["DINOv2: Learning Robust Visual Features without Supervision"](https://github.com/facebookresearch/dinov2)
//! The DINOv2-reg4 model is a variant of DINOv2 that adds 4 regularization (register) tokens to the
//! original architecture. The weights used in the candle example are trained for plant species
//! classification on the PlantCLEF2024 dataset with 7,806 classes.
//!
//! - [Paper](https://arxiv.org/abs/2309.16588): Vision Transformers Need Registers
//! - [GH Repo](https://github.com/facebookresearch/dinov2)
//!
//! # Example
//!
//! ```bash
//! # Download classes names and a plant picture to identify
//! # see candle/examples/dinov2reg4 for full code.
//!
//! # Perform inference
//! cargo run \
//! --example dinov2reg4 \
//! --release -- \
//! --image <orchid-file>
//!
//! > Orchis simia Lam. : 45.55%
//! > Orchis × bergonii Nanteuil: 9.80%
//! > Orchis italica Poir. : 9.66%
//! > Orchis × angusticruris Franch.: 2.76%
//! > Orchis × bivonae Tod. : 2.54%
//! ```
//!
//! <div align=center>
//! <img src="https://bs.plantnet.org/image/o/bd2d3830ac3270218ba82fd24e2290becd01317c" alt="" width=320>
//! </div>
//!
use candle::{IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
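The top-5 readout in the output above is just a softmax plus a sort against the PlantCLEF2024 class names. A small sketch of that post-processing step; the one-name-per-index `class_names` slice is an assumption about how the mapping is supplied, the example downloads its own mapping file.

```rust
use candle_core::{Result, Tensor, D};

/// Print the top-k entries of a [1, n_classes] logits tensor, given one
/// class name per index in `class_names`.
fn print_topk(logits: &Tensor, class_names: &[String], k: usize) -> Result<()> {
    let probs = candle_nn::ops::softmax(logits, D::Minus1)?
        .flatten_all()?
        .to_vec1::<f32>()?;
    let mut indexed: Vec<(usize, f32)> = probs.into_iter().enumerate().collect();
    indexed.sort_by(|a, b| b.1.total_cmp(&a.1));
    for (idx, p) in indexed.into_iter().take(k) {
        println!("{:30}: {:.2}%", class_names[idx], 100. * p as f64);
    }
    Ok(())
}
```

The eva2 and efficientvit examples below print their ImageNet top-5 the same way, just with the 1000 ImageNet labels.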
37 changes: 34 additions & 3 deletions candle-transformers/src/models/efficientvit.rs
@@ -1,9 +1,40 @@
//! EfficientViT (MSRA) inference implementation based on timm.
//!
//! See ["EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention"](https://arxiv.org/abs/2305.07027)
//! This module provides an implementation of the EfficientViT model from Microsoft Research Asia
//! for efficient image classification. The model uses cascaded group attention modules
//! to achieve strong performance while maintaining low memory usage.
//!
//! The model was originally described in the paper:
//! ["EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention"](https://arxiv.org/abs/2305.07027)
//!
//! This implementation is based on the reference implementation from
//! [pytorch-image-models](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientvit_msra.py).
//!
//! # Example Usage
//!
//! This candle implementation uses a pre-trained EfficientViT (from Microsoft Research Asia) network for inference.
//! The classification head has been trained on the ImageNet dataset and returns the probabilities for the top-5 classes.
//!
//! ```bash
//! cargo run \
//! --example efficientvit \
//! --release -- \
//! --image candle-examples/examples/yolo-v8/assets/bike.jpg --which m1
//!
//! > loaded image Tensor[dims 3, 224, 224; f32]
//! > model built
//! > mountain bike, all-terrain bike, off-roader: 69.80%
//! > unicycle, monocycle : 13.03%
//! > bicycle-built-for-two, tandem bicycle, tandem: 9.28%
//! > crash helmet : 2.25%
//! > alp : 0.46%
//! ```
//!
//! <div align=center>
//! <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.jpg" alt="" width=640>
//! </div>
//!
use candle::{Result, Tensor, D};
use candle_nn::{
batch_norm, conv2d, conv2d_no_bias, linear, ops::sigmoid, ops::softmax, Conv2dConfig, Func,
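The `loaded image Tensor[dims 3, 224, 224; f32]` line corresponds to the ImageNet preprocessing shared by the classifier examples: resize to 224×224, scale to [0, 1], then normalize per channel. A sketch of that step using the `image` crate; the resize filter and the exact helper layout in candle-examples may differ, so treat this as an approximation.

```rust
use candle_core::{DType, Device, Result, Tensor};

/// Load `path`, resize to 224x224 and return an ImageNet-normalized
/// tensor of shape [3, 224, 224].
fn load_image224(path: &str, device: &Device) -> Result<Tensor> {
    let img = image::open(path)
        .map_err(candle_core::Error::wrap)?
        .resize_exact(224, 224, image::imageops::FilterType::Triangle)
        .to_rgb8();
    let data = img.into_raw(); // HWC bytes
    let tensor = Tensor::from_vec(data, (224, 224, 3), device)?
        .permute((2, 0, 1))? // to CHW
        .to_dtype(DType::F32)?;
    let mean = Tensor::new(&[0.485f32, 0.456, 0.406], device)?.reshape((3, 1, 1))?;
    let stdev = Tensor::new(&[0.229f32, 0.224, 0.225], device)?.reshape((3, 1, 1))?;
    (tensor / 255.)?.broadcast_sub(&mean)?.broadcast_div(&stdev)
}
```

Feeding the resulting tensor, unsqueezed to a batch of one, into the model and softmaxing the output reproduces the percentages shown above.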
28 changes: 25 additions & 3 deletions candle-transformers/src/models/eva2.rs
@@ -1,9 +1,31 @@
//! EVA-02 inference implementation.
//!
//! See ["EVA-02: A Visual Representation for Neon Genesis"](https://arxiv.org/abs/2303.11331)
//! EVA-02 is a computer vision model that can be used as an ImageNet classifier.
//! The model returns the probability for an image to belong to each of the 1000
//! ImageNet categories.
//!
//! - [Paper](https://arxiv.org/abs/2303.11331). EVA-02: A Visual Representation for Neon Genesis
//! - [Code](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/eva2.py)
//!
//! # Example
//!
//! ```bash
//! cargo run \
//! --example eva2 \
//! --release -- \
//! --image candle-examples/examples/yolo-v8/assets/bike.jpg
//!
//! > mountain bike, all-terrain bike, off-roader: 37.09%
//! > maillot : 8.30%
//! > alp : 2.13%
//! > bicycle-built-for-two, tandem bicycle, tandem: 0.84%
//! > crash helmet : 0.73%
//! ```
//!
//! <div align=center>
//! <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.jpg" alt="" width=640>
//! </div>
//!
use candle::{IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};

30 changes: 28 additions & 2 deletions candle-transformers/src/models/moondream.rs
@@ -1,13 +1,39 @@
//! MoonDream, a vision-to-text model.
//!
//! Moondream is a computer-vision model that can answer real-world questions about images.
//! It's lightweight with only 1.6B parameters, enabling it to run on mobile phones and edge devices.
//! [MoonDream Original Implementation](https://github.com/vikhyat/moondream)
//!
//! The model consists of:
//! - Vision encoder using a ViT-style architecture
//! - Text decoder based on Microsoft's Phi model
//! - Vision projection module to align vision and text embeddings
//!
//! # Examples
//!
//! <img src="https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg" width="200">
//!
//! ```bash
//! # download an example image
//! wget https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg
//!
//! # Now you can run Moondream from the `candle-examples` crate:
//! cargo run --example moondream \
//! --release -- \
//! --prompt "What is the girl eating?" \
//! --image "./demo-1.jpg"
//!
//! > avx: false, neon: true, simd128: false, f16c: false
//! > temp: 0.00 repeat-penalty: 1.00 repeat-last-n: 64
//! > retrieved the files in 3.395583ms
//! > Running on CPU, to run on GPU(metal), build this example with `--features metal`
//! > loaded the model in 5.485493792s
//! > loaded and encoded the image Tensor[dims 3, 378, 378; f32] in 4.801396417s
//! > starting the inference loop
//! > The girl is eating a hamburger.<
//! > 9 tokens generated (0.68 token/s)
//! ```
use crate::models::mixformer::{Config as PhiConfig, MixFormerSequentialForCausalLM as PhiModel};
use crate::models::with_tracing::{layer_norm, linear_b, LayerNorm, Linear};
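The `starting the inference loop` line in the log above marks an ordinary token-by-token decode. A skeleton of such a loop using `candle_transformers::generation::LogitsProcessor`; the `step` closure stands in for the Moondream text-decoder call (with the projected image embeddings already folded into the prompt) and is a placeholder, not the module's real API.

```rust
use candle_core::{Result, Tensor};
use candle_transformers::generation::LogitsProcessor;

/// Greedy decoding skeleton: `step` maps the tokens generated so far to
/// next-token logits of shape [vocab_size].
fn generate(
    mut step: impl FnMut(&[u32]) -> Result<Tensor>,
    prompt: Vec<u32>,
    eos_token: u32,
    max_tokens: usize,
) -> Result<Vec<u32>> {
    // temperature None => argmax sampling; the seed only matters when sampling.
    let mut sampler = LogitsProcessor::new(299792458, None, None);
    let mut tokens = prompt;
    for _ in 0..max_tokens {
        let logits = step(&tokens)?;
        let next = sampler.sample(&logits)?;
        if next == eos_token {
            break;
        }
        tokens.push(next);
    }
    Ok(tokens)
}
```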
20 changes: 18 additions & 2 deletions candle-transformers/src/models/rwkv_v5.rs
@@ -1,7 +1,9 @@
//! RWKV v5 model implementation.
//!
//! RWKV is an RNN with transformer-level performance that can be implemented
//! as either a transformer or RNN.
//! The [RWKV model](https://wiki.rwkv.com/) is a recurrent neural network model
//! with performance on par with transformer architectures. Several variants are
//! available; candle implements the v5 and v6 versions, which can be used with
//! Eagle 7B ([blog post](https://blog.rwkv.com/p/eagle-7b-soaring-past-transformers)).
//!
//! Key characteristics:
//! - Time-mix attention mechanism
Expand All @@ -14,6 +16,20 @@
//! - [RWKV Language Model](https://github.com/BlinkDL/RWKV-LM)
//! - [RWKV v5 Release](https://github.com/BlinkDL/ChatRWKV/tree/main)
//!
//! # Example
//!
//! ```bash
//! cargo run --example rwkv --release -- \
//! --prompt "The smallest prime is "
//!
//! > avx: true, neon: false, simd128: false, f16c: true
//! > temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64
//! > The smallest prime is ϕ(2) = 2.
//! > The smallest composite is ϕ(3) = 3.
//! > The smallest perfect number is ϕ(5) = 5.
//! > The smallest perfect square is ϕ(4) = 4.
//! > The smallest perfect cube is ϕ(6) = 6.
//! ```
use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor};
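What distinguishes the RWKV decode loop from a KV-cache transformer is the fixed-size recurrent state threaded through every step. A sketch of that loop; the `Model`, `Config` and `State` names and the `forward(&input, &mut state)` signature mirror the candle rwkv example but should be read as assumptions.

```rust
use candle_core::{Device, Result, Tensor};
use candle_transformers::models::rwkv_v5::{Config, Model, State};

/// Feed tokens one at a time; the state carries everything the model needs
/// to remember, so memory stays constant regardless of sequence length.
fn prefill(model: &Model, config: &Config, tokens: &[u32], device: &Device) -> Result<Tensor> {
    let mut state = State::new(1, config, device)?;
    let mut last_logits = None;
    for &t in tokens {
        // One token per step, shape [batch = 1, seq = 1].
        let input = Tensor::new(&[[t]], device)?;
        last_logits = Some(model.forward(&input, &mut state)?);
    }
    // Next-token logits come from the final step; the example squeezes the
    // batch dimension and hands them to a LogitsProcessor.
    Ok(last_logits.expect("empty prompt"))
}
```

The rwkv_v6 module below keeps the same interface, so the same loop should drive it unchanged.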
21 changes: 17 additions & 4 deletions candle-transformers/src/models/rwkv_v6.rs
@@ -1,7 +1,9 @@
//! RWKV v6 model implementation.
//!
//! RWKV is an RNN with transformer-like performance.
//! Version 6 introduces refinements to the architecture.
//! The [RWKV model](https://wiki.rwkv.com/) is a recurrent neural network model
//! with performance on par with transformer architectures. Several variants are
//! available; candle implements the v5 and v6 versions, which can be used with
//! Eagle 7B ([blog post](https://blog.rwkv.com/p/eagle-7b-soaring-past-transformers)).
//!
//! Key characteristics:
//! - Linear attention mechanism
Expand All @@ -10,9 +12,20 @@
//! - Feed forward gating
//! - State recycling for efficient inference
//!
//! References:
//! - [RWKV Model](https://github.com/BlinkDL/RWKV-LM)
//! # Example
//!
//! ```bash
//! cargo run --example rwkv --release -- \
//! --prompt "The smallest prime is "
//!
//! > avx: true, neon: false, simd128: false, f16c: true
//! > temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64
//! > The smallest prime is ϕ(2) = 2.
//! > The smallest composite is ϕ(3) = 3.
//! > The smallest perfect number is ϕ(5) = 5.
//! > The smallest perfect square is ϕ(4) = 4.
//! > The smallest perfect cube is ϕ(6) = 6.
//! ```
use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear};
use candle::{IndexOp, Result, Tensor};
29 changes: 26 additions & 3 deletions candle-transformers/src/models/segment_anything/mod.rs
@@ -1,10 +1,33 @@
//! Segment Anything Model (SAM)
//!
//! SAM is an architecture for image segmentation, capable of segmenting any object
//! in an image based on prompts like points or boxes.
//! in an image based on prompts like points or boxes.
//!
//! This model provides a robust and fast image segmentation pipeline that can be tweaked via
//! some prompting (requesting some points to be in the target mask, requesting some
//! points to be part of the background so _not_ in the target mask, specifying some
//! bounding box).
//!
//! - [GH Link](https://github.com/facebookresearch/segment-anything)
//! - [Paper](https://arxiv.org/abs/2304.02643)
//! - ⚡ [Interactive Wasm Example](https://huggingface.co/spaces/radames/candle-segment-anything-wasm)
//! - 💻 [GH Link](https://github.com/facebookresearch/segment-anything)
//! - 📝 [Paper](https://arxiv.org/abs/2304.02643)
//! - 💡 The default backbone can be replaced by the smaller and faster TinyViT model based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).
//!
//!
//! ## Example
//!
//! ```bash
//! cargo run --example segment-anything --release -- \
//! --image candle-examples/examples/yolo-v8/assets/bike.jpg \
//! --use-tiny --point 0.6,0.6 --point 0.6,0.55
//! ```
//!
//! <div align=center style="display: flex; justify-content: center; gap: 10px;">
//! <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.jpg" alt="" width="30%">
//! <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/single_pt_prompt.jpg" alt="" width="30%">
//! <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/two_pt_prompt.jpg" alt="" width="30%">
//! </div>
//!
//!
//! > Original; Prompt with `--point 0.6,0.55`; Prompt with `--point 0.6,0.6 --point 0.6,0.55`
//!
pub use crate::models::with_tracing::Linear;
use candle::{Result, Tensor};
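The `--point` flags above reach the model as normalized `(x, y, is_foreground)` triples. A rough sketch of a point-prompted call; the `Sam::new_tiny` constructor, the checkpoint file name, and the `forward(&image, &points, multimask_output)` signature follow the candle segment-anything example and are assumptions here.

```rust
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::segment_anything::sam::Sam;

fn main() -> Result<()> {
    let device = Device::Cpu;

    // Assumed local checkpoint for the TinyViT (MobileSAM-style) backbone.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(
            &["mobile_sam-tiny-vitt.safetensors"],
            DType::F32,
            &device,
        )?
    };
    let sam = Sam::new_tiny(vb)?;

    // Stand-in image tensor; the example loads and resizes the real picture
    // to the model's expected input size before this call.
    let image = Tensor::zeros((3, 1024, 1024), DType::F32, &device)?;

    // Two foreground points, expressed as fractions of the image dimensions.
    let points = vec![(0.6, 0.6, true), (0.6, 0.55, true)];
    let (mask, iou_prediction) = sam.forward(&image, &points, false)?;
    println!("mask: {mask}, predicted iou: {iou_prediction}");
    Ok(())
}
```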
