Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed instruction::replace() logic. #3553

Merged
merged 6 commits into from
Oct 29, 2024
Merged

Conversation

tcgu-amd
Copy link
Contributor

@tcgu-amd tcgu-amd commented Oct 24, 2024

The previous fix with BFS doesn't fully work in more complex cases (e.g. it will fail in the newly added test case check_replace_dag). This fix implements topological sorting to replace instructions in topological order which should work for all cases.

More details:

In a dummy scenario of add2(reduce(x), add1(abs(reduce(x)), sin(reduce(x)))), we will have a dependency tree looking like

reduce _
        \_abs__
         \_sin__\_add1_
          \_____________\_add2

If we call reduce.replace(), BFS will visit the instructions in the following order:

reduce -> abs -> sin -> add2 -> add1

This will causes an error of shape mismatch at add2 because it is called before its input add1.

Topological sorting the instruction tree will yield:

reduce -> sin -> abs -> add1 -> add2

Which is the correct order to process the instructions.

This should be able to extend to more complex cases.

… fully work in more complex cases (e.g. it will fail in the newly added test case check_replace_dag). This fix implements topological sorting to replace instruction in topological order which should work for all cases.
@tcgu-amd tcgu-amd requested a review from causten as a code owner October 24, 2024 18:30
Copy link

codecov bot commented Oct 24, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 92.16%. Comparing base (1e1a229) to head (92ebe7f).

Additional details and impacted files
@@           Coverage Diff            @@
##           develop    #3553   +/-   ##
========================================
  Coverage    92.16%   92.16%           
========================================
  Files          512      512           
  Lines        21401    21408    +7     
========================================
+ Hits         19724    19731    +7     
  Misses        1677     1677           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@pfultz2
Copy link
Collaborator

pfultz2 commented Oct 24, 2024

This seems like it will be really slow since it needs topologically sort until end of the model instead of just until the shapes no longer change.

@tcgu-amd
Copy link
Contributor Author

tcgu-amd commented Oct 24, 2024

This seems like it will be really slow since it needs topologically sort until end of the model instead of just until the shapes no longer change.

Yes unfortunately I think this is definitely going to be slower than the previous implementations. I am not quite sure if there's potentially a better approach since we don't know the dependencies of instructions beforehand until after the sort.

One way I can think of is to take an optimistic approach and perform BFS assuming everything is going to be fine, and on shape mismatch just push the instruction to the back of the queue. Only return the error if all instructions in the queue are shape mismatches. This is a little bit unconventional so I will need to test it to make sure it is going to generate correct results.

Edit: Actually, upon further consideration, I think this problem can be solved easily by using a modified version of Kahn's algorithm. I will update the code and try it out.

@pfultz2
Copy link
Collaborator

pfultz2 commented Oct 24, 2024

There might be a way to traverse up the inputs to check for dependencies. I would need to think about it more.

…ly based on Khan's algorithm.

This version avoids sorting the entire graph, and will terminate when no more changes are requried like old versions
@tcgu-amd
Copy link
Contributor Author

tcgu-amd commented Oct 25, 2024

Hi @pfultz2, I have created a new version of the algorithm that should have the same performance as the old versions.

This is loosely based on Khan's algorithm in that we only process nodes that has been visited by all its children that needs to be replaced.

To achieve this, we perform a BFS from the base instruction as usual, but keep a map counting the number of arguments for each instruction we encounter. If it an instruction is unary, then we can directly process the current instruction. If there's more than one argument, we subtract one from the number of arguments in the map and check to see if the number reaches zero, in which case all of the arguments must have been replaced and we can replace this instruction; otherwise some arguments may still need to be replaced, and we can just skip replacing this instruction for now and wait for it to be encounter again when one of its arguments ultimately adds it back to the queue.

For instructions that have more than one child, but only one of them needs to be replaced and the other ones are from unrelated sub-graphs, we can add them from the map to the queue when it empties, and try to process them. If this ends up generates a shape mismatch it will error out as normal.

Edit:

For instructions that have more than one child, but only one of them needs to be replaced and the other ones are from unrelated sub-graphs, we can add them from the map to the queue when it empties, and try to process them. If this ends up generates a shape mismatch it will error out as normal.

I just realized that there might still be a dependency between the instructions that needs to be partially replaced, and the current version may not be able to capture that..

@pfultz2
Copy link
Collaborator

pfultz2 commented Oct 25, 2024

I would think instead you would check if the inputs reaches the instruction and then add that to a revisit queue:

void instruction::replace(const shape& r)
{
    if(r != result)
    {
        result = r;
        std::deque<instruction_ref> q(output.begin(), output.end());
        std::deque<instruction_ref> revisit;
        std::unordered_set<instruction_ref> visited;
        while(not q.empty())
        {
            instruction_ref ins = q.front();
            q.pop_front();
            if(not visited.insert(ins).second)
                continue;
            assert(ins->name() == "@return" or ins->name().front() != '@');
            shape new_r = compute_shape(ins->op, ins->arguments, ins->module_args);
            if(new_r != ins->result)
            {
                ins->result = new_r;
                for(auto out:ins->outputs())
                {
                    if(any_of(out->inputs(), [&](instruction_ref x) { return x != ins and reaches(ins, x); }))
                    {
                        revisit.push_back(out);
                    }
                    else
                    {
                        q.push_back(ins);
                    }
                }
            }
            if(q.empty())
            {
                q.insert(q.end(), revisit.begin(), revisit.end());
                revisit.clear();
            }
        }
    }
}

This would fix the simple case you presented but I am not sure it would handle more complicated cases.

@pfultz2
Copy link
Collaborator

pfultz2 commented Oct 25, 2024

Actually, I think it might be much simpler if we just use the order in the instruction list as that should already be in order. So we could just use a priority_queue instead:

struct replace_shape_order
{
    instruction_ref start;

    std::size_t location(instruction_ref x) const
    {
        return std::distance(start, x);
    }

    bool operator()(instruction_ref x, instruction_ref y) const
    {
        return location(x) > location(y);
    }
};

void instruction::replace(const shape& r)
{
    if(r != result)
    {
        result = r;
        auto start = std::find_if(output.front()->inputs().begin(), output.front()->inputs().end(), [&](instruction_ref x) {
            return this == as_address(x);
        });
        assert(as_address(*start) == this);
        std::priority_queue<instruction_ref, std::vector<instruction_ref>, replace_shape_order> q(output, replace_shape_order{*start});
        while(not q.empty())
        {
            instruction_ref ins = q.top();
            q.pop();
            assert(ins->name() == "@return" or ins->name().front() != '@');
            shape new_r = compute_shape(ins->op, ins->arguments, ins->module_args);
            if(new_r != ins->result)
            {
                ins->result = new_r;
                std::copy(ins->output.begin(), ins->output.end(), push_inserter(q));
            }
        }
    }
}

@tcgu-amd
Copy link
Contributor Author

Actually, I think it might be much simpler if we just use the order in the instruction list as that should already be in order. So we could just use a priority_queue instead:

struct replace_shape_order
{
    instruction_ref start;

    std::size_t location(instruction_ref x) const
    {
        return std::distance(start, x);
    }

    bool operator()(instruction_ref x, instruction_ref y) const
    {
        return location(x) > location(y);
    }
};

void instruction::replace(const shape& r)
{
    if(r != result)
    {
        result = r;
        auto start = std::find_if(output.front()->inputs().begin(), output.front()->inputs().end(), [&](instruction_ref x) {
            return this == as_address(x);
        });
        assert(as_address(*start) == this);
        std::priority_queue<instruction_ref, std::vector<instruction_ref>, replace_shape_order> q(output, replace_shape_order{*start});
        while(not q.empty())
        {
            instruction_ref ins = q.top();
            q.pop();
            assert(ins->name() == "@return" or ins->name().front() != '@');
            shape new_r = compute_shape(ins->op, ins->arguments, ins->module_args);
            if(new_r != ins->result)
            {
                ins->result = new_r;
                std::copy(ins->output.begin(), ins->output.end(), push_inserter(q));
            }
        }
    }
}

This makes sense! Much more elegant too! I will test it out.

@tcgu-amd
Copy link
Contributor Author

@pfultz2 I pushed a commit with the new solution you proposed. Seems like it is working. Worth noting that there is no std::inserter for priority_queue so we can't use std::copy to insert the instruction outputs.

@pfultz2
Copy link
Collaborator

pfultz2 commented Oct 28, 2024

Worth noting that there is no std::inserter for priority_queue so we can't use std::copy to insert the instruction outputs.

You could add one to the migraphx/output_iterator.hpp header:

template <class Container>
auto push_inserter(Container& c)
{
    return make_function_output_iterator([&](const auto& x) { c.push(x); });
}

@causten causten merged commit 58cf599 into ROCm:develop Oct 29, 2024
13 of 20 checks passed
causten added a commit that referenced this pull request Oct 29, 2024
causten added a commit that referenced this pull request Oct 29, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants