
Commit

Add differentiable optimization module. (Meta-Descent, KFO, Meta-Curvature) (#151)

* Ported hypergrad example.

* Add meta-curvature example with GBML wrapper.

* GBML support for nograd, unused, first_order and tests.

* Add ANIL+KFO low-level example.

* Add misc nn layers.

* Update maml_update.

* Change download path for mini-imagenet tests.

* Add docs for differentiable sgd.

* Update docs, incl. for MetaWorld.

* KroneckerTransform docs.

* Docs for meta-curvature.

* Add docs for l2l.nn.misc.

* Add docs for kroneckers.

* Fix lint, add more docs.

* Add docs for GBML.

* Completes GBML docs.

* Rename meta_update -> update_module, and write docs.

* Fix lint, add docs for ParameterUpdate.

* Add docs for LearnableOptimizer

* Update changelog

* Update to readme, part 1

* Update README, part 2.

* Fix readme links

* Version bump.
seba-1511 authored Jul 8, 2020
1 parent 26bfee2 commit 63ff92e
Showing 36 changed files with 2,079 additions and 111 deletions.
15 changes: 14 additions & 1 deletion CHANGELOG.md
@@ -10,8 +10,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

### Changed

### Fixed


## v0.1.2

### Added

* New example: [Meta-World](https://github.com/rlworkgroup/metaworld) example with MAML-TRPO and its own env wrapper. (@[Kostis-S-Z](https://github.com/Kostis-S-Z))
* `l2l.vision.benchmarks` interface.
* Differentiable optimization utilities in `l2l.optim`, including `l2l.optim.LearnableOptimizer` for meta-descent.
* General gradient-based meta-learning wrapper in `l2l.algorithms.GBML`.
* Various `nn.Modules` in `l2l.nn`.
* `l2l.update_module` as a more general alternative to `l2l.algorithms.maml_update`.

### Changed

173 changes: 117 additions & 56 deletions README.md
@@ -4,81 +4,142 @@

[![Build Status](https://travis-ci.com/learnables/learn2learn.svg?branch=master)](https://travis-ci.com/learnables/learn2learn)

learn2learn is a software library for meta-learning research.

learn2learn builds on top of PyTorch to accelerate two aspects of the meta-learning research cycle:

* *fast prototyping*, essential in letting researchers quickly try new ideas, and
* *correct reproducibility*, ensuring that these ideas are evaluated fairly.

learn2learn provides low-level utilities and a unified interface to create new algorithms and domains, together with high-quality implementations of existing algorithms and standardized benchmarks.
It retains compatibility with [torchvision](https://pytorch.org/vision/), [torchaudio](https://pytorch.org/audio/), [torchtext](https://pytorch.org/text/), [cherry](http://cherry-rl.net/), and any other PyTorch-based library you might be using.

**Overview**

* [`learn2learn.data`](http://learn2learn.net/docs/learn2learn.data/): `TaskDataset` and transforms to create few-shot tasks from any PyTorch dataset.
* [`learn2learn.vision`](http://learn2learn.net/docs/learn2learn.vision/): Models, datasets, and benchmarks for computer vision and few-shot learning.
* [`learn2learn.gym`](http://learn2learn.net/docs/learn2learn.gym/): Environments and utilities for meta-reinforcement learning.
* [`learn2learn.algorithms`](http://learn2learn.net/docs/learn2learn.algorithms/): High-level wrappers for existing meta-learning algorithms.
* [`learn2learn.optim`](http://learn2learn.net/docs/learn2learn.optim/): Utilities and algorithms for differentiable optimization and meta-descent.

**Resources**

* Website: [http://learn2learn.net/](http://learn2learn.net/)
* Documentation: [http://learn2learn.net/docs/](http://learn2learn.net/docs/)
* Tutorials: [http://learn2learn.net/tutorials/getting_started/](http://learn2learn.net/tutorials/getting_started/)
* Examples: [https://github.com/learnables/learn2learn/tree/master/examples](https://github.com/learnables/learn2learn/tree/master/examples)
* GitHub: [https://github.com/learnables/learn2learn/](https://github.com/learnables/learn2learn/)
* Slack: [http://slack.learn2learn.net/](http://slack.learn2learn.net/)

## Installation

~~~bash
pip install learn2learn
~~~

## Snippets & Examples

The following snippets provide a sneak peek at the functionalities of learn2learn.

### High-level Wrappers

**Few-Shot Learning with MAML**

For more algorithms (ProtoNets, ANIL, Meta-SGD, Reptile, Meta-Curvature, KFO) refer to the [examples](https://github.com/learnables/learn2learn/tree/master/examples/vision) folder.
Most of them can be implemented with the `GBML` wrapper ([documentation](http://learn2learn.net/docs/learn2learn.algorithms/#gbml)); a `GBML` sketch follows the MAML snippet below.
~~~python
maml = l2l.algorithms.MAML(model, lr=0.1)
opt = torch.optim.SGD(maml.parameters(), lr=0.001)

for iteration in range(10):
    opt.zero_grad()
    task_model = maml.clone()  # torch.clone() for nn.Modules
    adaptation_loss = compute_loss(task_model)
    task_model.adapt(adaptation_loss)  # computes gradient, update task_model in-place
    evaluation_loss = compute_loss(task_model)
    evaluation_loss.backward()  # gradients w.r.t. maml.parameters()
    opt.step()
~~~
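
As a sketch of the `GBML` wrapper mentioned above, the same loop with a Meta-Curvature transform might look as follows. The constructor arguments shown (`module`, `transform`, `lr`) and the untuned learning rates are assumptions drawn from the `GBML` documentation, not a verbatim example.

~~~python
metalearner = l2l.algorithms.GBML(
    module=model,                                # same base model as the MAML snippet
    transform=l2l.optim.MetaCurvatureTransform,  # learnable gradient preconditioning
    lr=0.1,                                      # inner-loop (adaptation) learning rate
)
opt = torch.optim.SGD(metalearner.parameters(), lr=0.001)

for iteration in range(10):
    opt.zero_grad()
    task_model = metalearner.clone()
    task_model.adapt(compute_loss(task_model))   # inner step goes through the transform
    compute_loss(task_model).backward()          # outer gradients, incl. the transform's
    opt.step()
~~~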

**Meta-Descent with Hypergradient**

Learn any kind of optimization algorithm with the `LearnableOptimizer`. ([example](https://github.com/learnables/learn2learn/tree/master/examples/optimization) and [documentation](http://learn2learn.net/docs/learn2learn.optim/#learnableoptimizer))
~~~python
linear = nn.Linear(784, 10)
transform = l2l.optim.ModuleTransform(l2l.nn.Scale)
metaopt = l2l.optim.LearnableOptimizer(linear, transform, lr=0.01) # metaopt has .step()
opt = torch.optim.SGD(metaopt.parameters(), lr=0.001) # metaopt also has .parameters()

metaopt.zero_grad()
opt.zero_grad()
error = loss(linear(X), y)
error.backward()
opt.step() # update metaopt
metaopt.step() # update linear
~~~

### Learning Domains

**Custom Few-Shot Dataset**

Many standardized datasets (Omniglot, mini-/tiered-ImageNet, FC100, CIFAR-FS) are readily available in `learn2learn.vision.datasets`.
([documentation](http://learn2learn.net/docs/learn2learn.vision/#learn2learnvisiondatasets))
~~~python
dataset = l2l.data.MetaDataset(MyDataset())  # any PyTorch dataset
transforms = [  # Easy to define your own transform
    l2l.data.transforms.NWays(dataset, n=5),
    l2l.data.transforms.KShots(dataset, k=1),
    l2l.data.transforms.LoadData(dataset),
]
taskset = l2l.data.TaskDataset(dataset, transforms, num_tasks=20000)
for task in taskset:
    X, y = task
    # Meta-train on the task
~~~
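
The bundled datasets plug into the same pipeline. Below is a minimal sketch, assuming `FullOmniglot` accepts `root`, `transform`, and `download` arguments (check the `l2l.vision.datasets` documentation for the exact signatures):

~~~python
import torchvision as tv
import learn2learn as l2l

omniglot = l2l.vision.datasets.FullOmniglot(  # assumed constructor arguments
    root='~/data',
    transform=tv.transforms.ToTensor(),
    download=True,
)
dataset = l2l.data.MetaDataset(omniglot)  # then define transforms and a TaskDataset as above
~~~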

**Environments and Utilities for Meta-RL**

Parallelize your own meta-environments with `AsyncVectorEnv`, or use the standardized ones.
([documentation](http://learn2learn.net/docs/learn2learn.gym/#metaenv))
~~~python
def make_env():
    env = l2l.gym.HalfCheetahForwardBackwardEnv()
    env = cherry.envs.ActionSpaceScaler(env)
    return env

env = l2l.gym.AsyncVectorEnv([make_env for _ in range(16)])  # uses 16 threads
for task_config in env.sample_tasks(20):
    env.set_task(task_config)  # all threads receive the same task
    state = env.reset()  # use standard Gym API
    action = my_policy(state)
    env.step(action)
~~~
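
For custom meta-environments, the `MetaEnv` interface (see the documentation link above) amounts to implementing task sampling and task setting. A rough sketch follows; the 2D point dynamics and reward are illustrative assumptions, and `reset`/`step` follow the standard Gym API:

~~~python
import numpy as np
import learn2learn as l2l

class Point2DGoalEnv(l2l.gym.MetaEnv):
    # Rough sketch: observation/action spaces are omitted for brevity.

    def sample_tasks(self, num_tasks):
        goals = np.random.uniform(-1.0, 1.0, size=(num_tasks, 2))
        return [{'goal': goal} for goal in goals]

    def set_task(self, task):
        self._task = task
        self.goal = task['goal']

    def reset(self):
        self.state = np.zeros(2)
        return self.state

    def step(self, action):
        self.state = self.state + np.clip(action, -0.1, 0.1)
        reward = -np.linalg.norm(self.state - self.goal)
        return self.state, reward, False, {}
~~~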

### Low-Level Utilities

**Differentiable Optimization**

Learn and differentiate through updates of PyTorch Modules.
([documentation](http://learn2learn.net/docs/learn2learn.optim/#parameterupdate))
~~~python
model = MyModel()
transform = l2l.optim.KroneckerTransform(l2l.nn.KroneckerLinear)
learned_update = l2l.optim.ParameterUpdate(  # learnable update function
    model.parameters(), transform)
clone = l2l.clone_module(model)  # torch.clone() for nn.Modules
error = loss(clone(X), y)
updates = learned_update(  # similar API as torch.autograd.grad
    error,
    clone.parameters(),
    create_graph=True,
)
l2l.update_module(clone, updates=updates)
loss(clone(X), y).backward()  # Gradients w.r.t model.parameters() and learned_update.parameters()
~~~
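
The same two utilities also cover the plain, non-learned case. Here is a minimal sketch of a hand-rolled MAML-style inner step with a fixed SGD update (the inner learning rate of 0.1 is an arbitrary choice):

~~~python
clone = l2l.clone_module(model)
error = loss(clone(X), y)
grads = torch.autograd.grad(error, clone.parameters(), create_graph=True)
updates = [-0.1 * g for g in grads]  # SGD update: new_param = param + update
l2l.update_module(clone, updates=updates)
loss(clone(X), y).backward()  # gradients flow back to model.parameters()
~~~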

## Changelog

A human-readable changelog is available in the [CHANGELOG.md](CHANGELOG.md) file.

## Citation

@@ -101,5 +162,5 @@ You can also use the following Bibtex entry.
### Acknowledgements & Friends

1. The RL environments are adapted from Tristan Deleu's [implementations](https://github.com/tristandeleu/pytorch-maml-rl) and from the ProMP [repository](https://github.com/jonasrothfuss/ProMP/). Both shared with permission, under the MIT License.
2. [TorchMeta](https://github.com/tristandeleu/pytorch-meta) is a similar library, with a focus on datasets for supervised meta-learning.
3. [higher](https://github.com/facebookresearch/higher) is a PyTorch library that enables differentiating through optimization inner-loops. While they monkey-patch `nn.Module` to be stateless, learn2learn retains the stateful PyTorch look-and-feel. For more information, refer to [their ArXiv paper](https://arxiv.org/abs/1910.01727).
34 changes: 29 additions & 5 deletions docs/pydocmd.yml
@@ -6,9 +6,10 @@ site_name: "learn2learn"
# documented. Higher indentation leads to smaller header size.
generate:
- docs/learn2learn.md:
  - learn2learn:
    - learn2learn.clone_module
    - learn2learn.detach_module
    - learn2learn.update_module
    - learn2learn.magic_box
- docs/learn2learn.data.md:
  - learn2learn.data:
@@ -25,9 +26,8 @@ generate:
- docs/learn2learn.algorithms.md:
  - learn2learn.algorithms:
    - learn2learn.algorithms.MAML++
    - learn2learn.algorithms.MetaSGD++
    - learn2learn.algorithms.GBML++
- docs/learn2learn.gym.md:
  - learn2learn.gym++:
    - learn2learn.gym.MetaEnv
@@ -40,6 +40,27 @@
    - learn2learn.gym.envs.mujoco.HumanoidDirectionEnv
  - learn2learn.gym.envs.particles:
    - learn2learn.gym.envs.particles.Particles2DEnv
  - learn2learn.gym.envs.metaworld:
    - learn2learn.gym.envs.metaworld.MetaWorldML1++
    - learn2learn.gym.envs.metaworld.MetaWorldML10++
    - learn2learn.gym.envs.metaworld.MetaWorldML45++
- docs/learn2learn.optim.md:
  - learn2learn.optim++:
    - learn2learn.optim.LearnableOptimizer++
    - learn2learn.optim.ParameterUpdate++
    - learn2learn.optim.DifferentiableSGD++
  - learn2learn.optim.transforms:
    - learn2learn.optim.transforms.ModuleTransform++
    - learn2learn.optim.transforms.KroneckerTransform++
    - learn2learn.optim.transforms.MetaCurvatureTransform++
- docs/learn2learn.nn.md:
  - learn2learn.nn++:
    - learn2learn.nn.Lambda
    - learn2learn.nn.Flatten
    - learn2learn.nn.Scale
    - learn2learn.nn.KroneckerLinear
    - learn2learn.nn.KroneckerRNN
    - learn2learn.nn.KroneckerLSTM
- docs/learn2learn.vision.md:
  - learn2learn.vision++:
  - learn2learn.vision.models:
@@ -73,13 +94,16 @@ pages:
  - Feature Reuse with ANIL: tutorials/anil_tutorial/ANIL_tutorial.md
- Documentation:
  - learn2learn: docs/learn2learn.md
  - learn2learn.algorithms: docs/learn2learn.algorithms.md
  - learn2learn.data: docs/learn2learn.data.md
  - learn2learn.gym: docs/learn2learn.gym.md
  - learn2learn.algorithms: docs/learn2learn.algorithms.md
  - learn2learn.optim: docs/learn2learn.optim.md
  - learn2learn.nn: docs/learn2learn.nn.md
  - learn2learn.vision: docs/learn2learn.vision.md
  - learn2learn.gym: docs/learn2learn.gym.md
- Examples:
  - Computer Vision: examples.vision.md << ../examples/vision/README.md
  - Reinforcement Learning: examples.rl.md << ../examples/rl/README.md
  - Optimization: examples.optim.md << ../examples/optimization/README.md
- Changelog: changelog.md << ../CHANGELOG.md
- GitHub: https://github.com/learnables/learn2learn/

22 changes: 22 additions & 0 deletions examples/optimization/README.md
@@ -0,0 +1,22 @@
# Meta-Optimization

This directory contains examples of using learn2learn for meta-optimization or meta-descent.

## Hypergradient

The script `hypergrad_mnist.py` demonstrates how to implement a slightly modified version of "[Online Learning Rate Adaptation with Hypergradient Descent](https://arxiv.org/abs/1703.04782)".
The implementation departs from the algorithm presented in the paper in two ways.

1. We forgo the analytical formulation of the learning rate's gradient to demonstrate the capability of the `LearnableOptimizer` class.
2. We adapt per-parameter learning rates instead of updating a single learning rate shared by all parameters, as sketched below.
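
A minimal sketch of the idea, mirroring the `LearnableOptimizer` snippet in the main README (the layer sizes, learning rates, and `dataloader` below are illustrative assumptions, not the script's actual settings):

~~~python
import torch
import learn2learn as l2l

model = torch.nn.Linear(784, 10)
transform = l2l.optim.ModuleTransform(l2l.nn.Scale)   # one learnable scale per parameter
metaopt = l2l.optim.LearnableOptimizer(model, transform, lr=0.1)
opt = torch.optim.SGD(metaopt.parameters(), lr=1e-3)  # optimizes the learning rates themselves

for X, y in dataloader:  # assumed MNIST-style dataloader
    metaopt.zero_grad()
    opt.zero_grad()
    error = torch.nn.functional.cross_entropy(model(X.view(-1, 784)), y)
    error.backward()
    opt.step()      # hypergradient step: adapt the per-parameter learning rates
    metaopt.step()  # update the model with the adapted learning rates
~~~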

**Usage**

!!! warning
    The parameters for this script were not carefully tuned.

Manually edit the script and run:

~~~shell
python examples/optimization/hypergrad_mnist.py
~~~