Update pytorch.js (#842)
lutzroeder committed Jan 10, 2025
1 parent f26dc3d commit 372616e
Showing 2 changed files with 126 additions and 37 deletions.
161 changes: 125 additions & 36 deletions source/python.js
@@ -5414,6 +5414,49 @@ python.Execution = class {
throw new python.Error('Not implemented.');
}
});
this.registerFunction('torch._C.EqualNode', (lhs, rhs) => {
if (lhs === null && rhs === null) {
return true;
}
if (lhs === null || rhs === null) {
return false;
}
if (lhs.kind() !== rhs.kind()) {
return false;
}
const lhs_outputs = lhs.outputs();
const rhs_outputs = rhs.outputs();
if (lhs_outputs.length !== rhs_outputs.length) {
return false;
}
for (let i = 0; i < lhs_outputs.length; i++) {
const lt = lhs_outputs[i].type();
const rt = rhs_outputs[i].type();
if (lt !== rt) {
return false;
}
}
const lhs_inputs = lhs.inputs();
const rhs_inputs = rhs.inputs();
if (lhs_inputs.length !== rhs_inputs.length) {
return false;
}
if (!lhs_inputs.every((v, i) => v === rhs_inputs[i])) {
return false;
}
if (!torch._C.attributesEqualCSE(lhs, rhs)) {
return false;
}
if (lhs.blocks().length !== rhs.blocks().length) {
return false;
}
for (let i = 0; i < lhs.blocks().length; i++) {
if (lhs.blocks()[i] !== rhs.blocks()[i]) {
return false;
}
}
return true;
});
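Editorial note: the EqualNode helper mirrors the structural node-equality check PyTorch's JIT uses for common subexpression elimination (hence the attributesEqualCSE call above). Two nodes match only if kind, output types, inputs, attributes, and blocks all agree. A minimal usage sketch, not part of the commit: graph is assumed to be a torch.Graph (or torch.Block) built by this module, and the linear seen list stands in for the hash-based lookup a real CSE pass would use.

const seen = [];
for (const node of graph.nodes()) {
    const match = seen.find((candidate) => torch._C.EqualNode(candidate, node));
    if (match) {
        // Structurally identical: redirect every output to the earlier node and drop the duplicate.
        for (let i = 0; i < node.outputs().length; i++) {
            node.outputs()[i].replaceAllUsesWith(match.outputs()[i]);
        }
        node.destroy();
    } else {
        seen.push(node);
    }
}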
this.registerType('torch._C.ConstantPropagator', class {
constructor(graph, aliasing_types, ignore_custom_classes) {
this._made_change = false;
@@ -5430,6 +5473,24 @@ python.Execution = class {
}
propagateNode(/* n */) {
}
inlineIfBody(body) {
const n = body.owningNode();
for (const body_node of body.nodes()) {
body_node.moveBefore(n);
}
for (let i = 0; i < n.outputs().length; i++) {
n.outputs()[i].replaceAllUsesWith(body.outputs()[i]);
}
n.destroy();
}
inlineIf(n) {
const input_bool = torch._C.constant_as(n.input(), 'toBool');
torch._C.AT_ASSERT(input_bool !== null);
const block_index = input_bool ? 0 : 1;
this.ConstantPropagation(n.blocks()[block_index]);
this.inlineIfBody(n.blocks()[block_index]);
this._made_change = true;
}
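Editorial note: together, inlineIfBody and inlineIf implement If-folding for constant propagation. When the condition of a prim::If is a known constant, the surviving branch is propagated, its nodes are moved in front of the If, the If's outputs are rewired to the branch outputs, and the If node is destroyed. A sketch of a possible call site, not part of the commit; propagator is assumed to be a ConstantPropagator instance and node a node from the block being processed.

if (node.kind() === 'prim::If') {
    // torch._C.constant_as (registered further down in this diff) yields the value behind
    // a constant condition, or null when the condition is not a constant.
    const condition = torch._C.constant_as(node.input(), 'toBool');
    if (condition !== null) {
        propagator.inlineIf(node);
    }
}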
removeExtraIfOutputs(n) {
torch._C.TORCH_CHECK(n.kind() === 'prim::If');
const [true_block, false_block] = n.blocks();
@@ -9283,6 +9344,8 @@ python.Execution = class {
map.set('bool', torch.BoolType.get());
map.set('complex', torch.ComplexType.get());
map.set('str', torch.StringType.get());
map.set('None', torch.NoneType.get());
map.set('NoneType', torch.NoneType.get());
torch._C.string_to_type_lut.basePythonTypes = map;
}
return torch._C.string_to_type_lut.basePythonTypes;
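Editorial note: the two new entries let the annotation strings 'None' and 'NoneType' resolve through the same lookup table as the other base Python types. A small sketch of the effect, not part of the commit, assuming the table is obtained by calling torch._C.string_to_type_lut() as the memoization above suggests.

const lut = torch._C.string_to_type_lut();
lut.get('str');      // torch.StringType.get()
lut.get('None');     // torch.NoneType.get()
lut.get('NoneType'); // torch.NoneType.get()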
@@ -9779,8 +9842,7 @@ python.Execution = class {
insertFunctionCall(callee, matched) {
const func_name = callee.name();
const fn_constant = this.insertNode(this.create('prim::Constant')).s_('name', func_name).output().setType(torch._C.FunctionType.create(callee));
const inputs = [fn_constant];
// inputs.insert(inputs.end(), matched.inputs.begin(), matched.inputs.end());
const inputs = [fn_constant, ...matched.inputs];
const result = this.insertNode(this.create('prim::CallFunction', inputs)).output().setType(matched.return_types[0]);
return result;
}
@@ -9886,10 +9948,10 @@ python.Execution = class {
nodes() {
const nodes = [];
let current = this._input.next;
do {
while (current !== this._input.prev) {
nodes.push(current);
current = current.next;
} while (current !== this._input.prev);
}
return nodes;
}
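Editorial note: the switch from do/while to while fixes the empty-block case. A standalone sketch of the same walk, assuming, as the surrounding code suggests, that a block chains its nodes between a parameter sentinel (_input) and a return sentinel (_input.prev); when the block is empty, next already equals prev, so a do/while would have pushed one spurious node before its exit test ran.

const collectNodes = (input) => {
    const nodes = [];
    let current = input.next;
    while (current !== input.prev) {
        nodes.push(current);
        current = current.next;
    }
    return nodes;
};
// Hypothetical empty block: the parameter node's next and prev both point at the return node.
const ret = {};
const param = { next: ret, prev: ret };
ret.next = param;
ret.prev = param;
collectNodes(param); // [], whereas the old do/while form would have returned [ret]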
return_node() {
@@ -9960,8 +10022,9 @@ python.Execution = class {
});
this.registerType('torch.Node', class {
constructor(graph, kind) {
this._graph = graph;
this._kind = kind;
this._graph = graph;
this._owning_block = null;
this._values = new Map();
this._inputs = [];
this._outputs = [];
@@ -9970,6 +10033,7 @@ python.Execution = class {
this._prev = null;
this._next = null;
this._source_range = null;
this._op = null;
}
owningGraph() {
return this._graph;
@@ -9981,7 +10045,7 @@ python.Execution = class {
return this._kind;
}
schema() {
if (this._op === undefined) {
if (this._op === null) {
this._op = null;
const index = this._kind.indexOf('.');
const name = index === -1 ? this._kind : this._kind.substring(0, index);
@@ -10143,17 +10207,18 @@ python.Execution = class {
return this._outputs;
}
input(i) {
if (i === undefined && this._inputs.length !== 1) {
throw new python.Error('Node has multiple inputs.');
if (i === undefined) {
torch._C.AT_ASSERT(this._inputs.length === 1);
return this._inputs[0];
}
i = i || 0;
return this._inputs[i];
}
output() {
if (this._outputs.length !== 1) {
throw new python.Error('Node has multiple outputs.');
output(i) {
if (i === undefined) {
torch._C.AT_ASSERT(this._outputs.length === 1);
return this._outputs[0];
}
return this._outputs[0];
return this._outputs[i];
}
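Editorial note: the reworked accessors take an optional index, mirroring the C++ torch::jit::Node API. The argument-less form asserts that exactly one value exists, while the indexed form reads directly from the list. A short usage sketch, not part of the commit; node stands for any torch.Node built by this graph implementation.

const only = node.output();   // asserts node.outputs().length === 1
const first = node.input(0);  // indexed access, no single-value assertion
const cond = node.input();    // asserts node.inputs().length === 1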
hasUses() {
for (const o of this.outputs()) {
@@ -10167,10 +10232,10 @@ python.Execution = class {
return this._blocks;
}
addInput(value) {
if (this._graph !== value.owningGraph()) {
throw new python.Error('Value not in graph.');
}
value.uses().push(new torch.Use(this, this._inputs.length));
torch._C.AT_ASSERT(this._graph === value.owningGraph());
this._op = null;
const use = new torch.Use(this, this._inputs.length);
value.uses().push(use);
this._inputs.push(value);
return value;
}
@@ -10204,19 +10269,16 @@ python.Execution = class {
return this;
}
insertAfter(n) {
if (this.inBlockList() || !n.inBlockList() || !n.owningBlock()) {
throw new python.Error('Node not in block.');
}
if (n.kind() === 'prim::Return') {
throw new python.Error('Cannot insert after return.');
}
torch._C.AT_ASSERT(!this.inBlockList() || n.inBlockList());
torch._C.AT_ASSERT(n.owningBlock());
torch._C.TORCH_INTERNAL_ASSERT(n.kind() !== 'prim::Return');
this._owning_block = n.owningBlock();
const next = n.next;
const next = n.next;
n.next = this;
this.prev = n;
this.next = next;
next.prev = this;
// assignTopoPosition();
// this.assignTopoPosition();
return this;
}
allocNewInstance(g) {
Expand All @@ -10238,7 +10300,7 @@ python.Execution = class {
torch._C.AT_ASSERT(i < this._inputs.length);
const input_node = this._inputs[i];
const use_it = this.findUseForInput(i);
input_node._uses.splice(use_it.offset, 1);
input_node._uses = input_node._uses.filter((use) => use !== use_it);
this._inputs[i] = null;
return input_node;
}
@@ -10286,7 +10348,7 @@ python.Execution = class {
for (let i = 0; i < this._inputs.length; i++) {
this.dropInput(i);
}
this._inputs.splice(0, this._inputs.length);
this._inputs = [];
}
inBlockList() {
return this.next !== null;
@@ -10440,6 +10502,11 @@ python.Execution = class {
}
return out;
}
toString() {
const out = new io.StringIO();
this.print(out, 0, true);
return out.toString();
}
});
this.registerType('torch.Value', class {
constructor(node) {
@@ -10532,6 +10599,7 @@ python.Execution = class {
return this._value;
}
replaceFirstUseWith(newValue) {
torch._C.AT_ASSERT(this.owningGraph() === newValue.owningGraph());
const [u] = this.uses();
u.user._inputs[u.offset] = newValue;
newValue._uses.push(u);
@@ -10550,11 +10618,7 @@ python.Execution = class {
return this;
}
toString() {
const list = [];
list.push(this.debugName());
list.push(' : ');
list.push(this.type().toString());
return list.join('');
return `${this.debugName()} : ${this.type().toString()}`;
}
});
this.registerType('torch.Use', class {
@@ -11498,9 +11562,7 @@ python.Execution = class {
});
this.registerFunction('torch._C.insertGraph', (g, callee, inputs, value_map) => {
const value_map_func = (v) => value_map.get(v);
if (callee.inputs().length !== inputs.length) {
throw new python.Error('Invalid number of inputs.');
}
torch._C.AT_ASSERT(callee.inputs().length === inputs.length);
for (let i = 0; i < inputs.length; i++) {
value_map.set(callee.inputs()[i], inputs[i]);
}
@@ -11574,7 +11636,7 @@ python.Execution = class {
loadSource(/* source, the_namespace */) {
}
});
this.registerFunction('torch._C.getAllBuiltinFunctionsFor', () => {
this.registerFunction('torch._C.getAllBuiltinFunctionsFor', (name) => {
torch._C.registry = torch._C.registry || new torch._C.BuiltinFunctionRegistry();
return torch._C.registry.getAllBuiltinFunctionsFor(name);
});
@@ -12384,6 +12446,14 @@ python.Execution = class {
}
throw new python.Error('Unsupported constant literal.');
});
this.registerFunction('torch._C.constant_as', (v /*, target */) => {
const ivalue = torch._C.toIValue(v);
if (ivalue !== undefined) {
return ivalue;
// return ivalue[target]();
}
return null;
});
this.registerType('torch._C.NamedValue', class {
constructor(...args) {
if (args.length === 1) {
@@ -12554,6 +12624,9 @@ python.Execution = class {
return new torch._C.SpecialFormValue(form);
}
});
this.registerFunction('torch._C.makeMagic', (name, base) => {
return new torch._C.MagicMethod(name, base);
});
this.registerType('torch._C.BuiltinFunction', class extends torch._C.SugaredValue {
constructor(symbol, self) {
super();
@@ -13564,6 +13637,20 @@ python.Execution = class {
}
return out;
}
emitUnaryOp(tree, magicMethod, opSymbol) {
const inputs = [tree.operand];
const named_values = this.getNamedValues(inputs, /*maybe_unpack=*/false);
const val = torch._C.asSimple(torch._C.makeMagic(magicMethod, new torch._C.BuiltinFunction(opSymbol, null)).call(tree, this.method, named_values, [], 0));
if (val.node().kind() !== opSymbol) {
return val;
}
const maybe_out_stack = this.runNodeIfInputsAreConstant(val.node());
if (!maybe_out_stack) {
return val;
}
torch._C.TORCH_INTERNAL_ASSERT(maybe_out_stack.length === 1);
return this.graph.insertConstant(maybe_out_stack[0], tree);
}
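Editorial note: emitUnaryOp routes a Python unary operator through its magic method and the matching ATen builtin, then constant-folds the node when all of its inputs are constant. An illustrative sketch of the two outcomes, not part of the commit; the IR lines below are hand-written for illustration only.

// def f(x): return -x
//   emitUnaryOp(tree, '__neg__', 'aten::neg') inserts:   %1 = aten::neg(%x)
// def g(): return -(3)
//   every input of the aten::neg node is constant, so runNodeIfInputsAreConstant
//   evaluates it and the expression collapses to:        %1 = prim::Constant[value=-3]()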
emitAssignment(stmt) {
if (stmt.targets.length === 1) {
return this.emitSingleAssignment(stmt);
@@ -13868,6 +13955,8 @@ python.Execution = class {
return this.emitConst(tree);
} else if (tree instanceof ast.List) {
return this.emitListLiteral(tree, type_hint);
} else if (tree instanceof ast.UnaryOp && tree.op instanceof ast.USub) {
return this.emitUnaryOp(tree, '__neg__', 'aten::neg');
} else if (tree instanceof ast.Tuple) {
const values = this.getValues(tree.elts, /*maybe_unpack=*/true);
return this.graph.insertNode(this.graph.createTuple(values)).output();
2 changes: 1 addition & 1 deletion source/pytorch.js
@@ -147,7 +147,7 @@ pytorch.Graph = class {
delattr(param_node.outputs()[0], '');
}
for (const node of graph.nodes()) {
if (node.kind() === 'prim::Constant') {
if (node.kind() === 'prim::Constant' && node.hasAttribute('value')) {
const kind = node.kindOf('value');
const value = node[kind]('value');
for (const output of node.outputs()) {
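Editorial note: the added hasAttribute('value') guard skips constants that carry no payload. In TorchScript IR a None constant is emitted as a prim::Constant with no attributes, so the previous code would try to read a 'value' attribute that does not exist. A hand-written IR illustration, not part of the commit.

// %none : NoneType = prim::Constant()           (no 'value' attribute, now skipped)
// %two  : int      = prim::Constant[value=2]()  (has 'value', still inlined as before)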
