Skip to content

Commit

Permalink
Add an off switch for sampling, fix tense on 'shrinked'
Browse files Browse the repository at this point in the history
  • Loading branch information
slyubomirsky committed Nov 3, 2020
1 parent 6835e1a commit d953dbc
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 4 deletions.
22 changes: 18 additions & 4 deletions aten/src/ATen/CheckpointTensorImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ void CheckpointPool::evict() {
time_t pre = std::chrono::system_clock::now();
STATS.track("CheckpointPool::evict");
TORCH_CHECK(aps.size() > 0);
bool shrinked = false;
// shrunk: either something has been evicted or the pools have gotten smaller
bool shrunk = false;
int evict_idx = -1;
double evict_cost = INFINITY;
time_t current_time = std::chrono::system_clock::now();
Expand All @@ -171,7 +172,7 @@ void CheckpointPool::evict() {
// sampling a random independent subset of all evictable tensors to find the cheapest tensor to evict.
for (size_t i = 0; i < aps.size();) {
auto cannot_evict = [&]() {
shrinked = true;
shrunk = true;
remove_from_aps(i);
};
auto ap_strong = aps[i].lock();
Expand All @@ -189,11 +190,16 @@ void CheckpointPool::evict() {
evict_idx = i;
}
}
i += distrib(gen);

if (sample_tensors) {
i += distrib(gen);
} else {
i += 1;
}
}
}
if (evict_idx == -1) {
TORCH_CHECK(shrinked);
TORCH_CHECK(shrunk);
} else {
auto evict_from_idx = [&](size_t idx) {
auto ap_strong = aps[idx].lock();
Expand Down Expand Up @@ -290,6 +296,14 @@ void set_memory_budget(long budget) {
pool.has_memory_budget = true;
}

// Turns on sampling for the checkpoint pool's eviction search: with the flag
// set, CheckpointPool::evict() advances through its candidates by a random
// stride (distrib(gen)) instead of visiting every entry.
// NOTE(review): `pool` is declared outside this view — presumably the global
// CheckpointPool instance; confirm.
void enable_sampling() {
pool.sample_tensors = true;
}

// Turns off sampling for the checkpoint pool's eviction search: with the flag
// cleared, CheckpointPool::evict() advances one entry at a time, so every
// candidate is examined when picking the cheapest one to evict.
// NOTE(review): `pool` is declared outside this view — presumably the global
// CheckpointPool instance; confirm.
void disable_sampling() {
pool.sample_tensors = false;
}

void reset_profile() {
base_compute_time_ = 0;
remat_compute_time_ = 0;
Expand Down
2 changes: 2 additions & 0 deletions aten/src/ATen/CheckpointTensorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,8 @@ struct CheckpointPool {
std::vector<weak_intrusive_ptr<External>> exts;
std::random_device rd;
std::mt19937 gen = std::mt19937(rd());
// whether to take a square-root sample of the pool during an eviction loop
bool sample_tensors = true;
bool has_memory_budget = false;
long memory_budget;
void evict();
Expand Down
6 changes: 6 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@
- func: unset_memory_budget() -> ()
variants: function

- func: enable_sampling() -> ()
variants: function

- func: disable_sampling() -> ()
variants: function

- func: reset_profile() -> ()
variants: function

Expand Down

0 comments on commit d953dbc

Please sign in to comment.