Skip to content

Commit

Permalink
#59 implemented function in the script instead of olds variable param…
Browse files Browse the repository at this point in the history
…eters
  • Loading branch information
AmeZap05 committed Jul 23, 2024
1 parent 45b0a48 commit b739d12
Showing 1 changed file with 53 additions and 51 deletions.
104 changes: 53 additions & 51 deletions sxpat/template_manager/template_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,8 @@ def get_func(name: str) -> str: return self._current_graph.graph.nodes[name][sxp
cell=f'({self._specs.lpp}, {self._specs.ppo})',
)

builder.update(connection_constraint="")


class SOPSManager(ProductTemplateManager):

Expand Down Expand Up @@ -723,8 +725,8 @@ def _output_mul_parameter(output_i: int) -> Tuple[str, str]:
def _output_mul_sw(output_i: int) -> Tuple[str, str]:
return f'mult_o{output_i}'
@staticmethod
def _node_connection(from_n: int, from_lv: int, to_n: int, to_lv:int) -> str:
return f'p_con_fn{from_n}_lv{from_lv}_tn{to_n}_lv{to_lv}'
def _node_connection(output_i: int,from_n: int, from_lv: int, to_n: int, to_lv:int) -> str:
return f'p_con_o{output_i}_fn{from_n}_lv{from_lv}_tn{to_n}_lv{to_lv}'

@staticmethod
def _switch_parameter(from_n: int, from_lv: int, to_n: int, to_lv:int):
Expand All @@ -735,8 +737,8 @@ def _input_connection(from_n: int, to_in: int) -> str:
return f'p_con_fn{from_n}_lv0_tin{to_in}'

@staticmethod
def _allow_node_output(output_i: int, nd_i: int, lv_i: int) -> str:
return f'p_allow_o{output_i}_n{nd_i}_lv{lv_i}'
def _allow_node_output(nd_i: int, lv_i: int) -> str:
return f'p_allow_n{nd_i}_lv{lv_i}'

@staticmethod
def _level_parameter(node_i: int,level_i: int):
Expand Down Expand Up @@ -766,7 +768,7 @@ def count_possible_connections(self, connections):
#TODO: Refactor: create classmethods for generate script,(it wourld be more readable)

def _multiplexer_multilevel(self,input_i,input_i__name,output_i,node_i):
return f'Or(Not({self._input_parameters(output_i,input_i,node_i)[0]} == {input_i__name}),{self._input_parameters(output_i,input_i,node_i)[1]})'
return f'\n Or(Not({self._input_parameters(output_i,input_i,node_i)[0]} == {input_i__name}),{self._input_parameters(output_i,input_i,node_i)[1]})'
#"If(p_con_fn"+str(n_gate) +"_lv0_tin"+str(input_i)+", Or(Not(p_i"+str(input_i)+"_l == "+ str(input_i__name) +"),p_i"+str(input_i)+"_s),True)"

def _generate_input(self, output_i,node_i):
Expand All @@ -782,29 +784,18 @@ def _connection_constraints(self, npl, level_i, gate, output_i):
if level_i == 0:
return self._generate_input(output_i,gate)
for node in range(npl[level_i-1]): #param connection from node# to node#s
self._node_connection(node,level_i-1,gate,level_i)
gates_per_level += f'If({self._node_connection(node,level_i-1,gate,level_i)}, If({self._switch_parameter(node,level_i-1,gate,level_i)}, {self._level_parameter(node,level_i-1)}, Not({self._level_parameter(node,level_i-1)})), True),'
#gates_per_level += "If( p_con_fn"+str(node)+"_lv"+str(level_i-1)+"_tn"+str(gate)+"_lv"+str(level_i)+", If(p_sw_fn"+str(node)+"_lv"+str(level_i-1)+"_tn"+str(gate)+"_lv"+str(level_i)+", n"+ str(node) + "_lv" + str(level_i-1) +", Not(n"+str(node)+"_lv"+str(level_i-1)+")), True),"
gates_per_level += f'\nIf({self._node_connection(output_i,node,level_i-1,gate,level_i)}, If({self._switch_parameter(node,level_i-1,gate,level_i)}, {self._level_parameter(node,level_i-1)}(), Not({self._level_parameter(node,level_i-1)}())), True),'

return gates_per_level

def _generate_levels(self,npl,output_i):
gates_per_level = ""

for level_i in range(len(npl)-1, -1, -1):
for gate in range( npl[level_i] ):
if level_i < len(npl) - 1:
gates_per_level += f'\n#level: {level_i}\n {self._level_parameter(gate,level_i)} == Or({self._allow_node_output(output_i,gate,level_i)}, And({self._id_parameter(gate,level_i)} == And({self._connection_constraints(npl, level_i, gate, output_i)}), If({self._neg_parameter(gate,level_i)}, Not({self._id_parameter(gate,level_i)}), {self._id_parameter(gate,level_i)}))),'
#gates_per_level += "\n#level"+str(level_i)+"\n n"+ str(gate) +"_lv"+ str(level_i)+ " == Or( p_allow_o"+str(output_i)+"_n"+str(gate)+"_lv"+str(level_i)+", And(p_id_n"+str(gate)+"_lv"+str(level_i)+" == " + self._connection_constraints(npl,level_i,gate) + "If(p_neg_n"+str(gate)+"_lv"+str(level_i)+", Not(p_id_n"+str(gate)+"_lv"+str(level_i)+"), p_id_n"+str(gate)+"_lv"+str(level_i)+"))),"
else:
gates_per_level+= f'\n#level: {level_i}\n Or({self._allow_node_output(output_i,gate,level_i)}, And({self._id_parameter(gate,level_i)} == And({self._connection_constraints(npl, level_i, gate,output_i)}), If({self._neg_parameter(gate, level_i)}, Not({self._id_parameter(gate,level_i)}), {self._id_parameter(gate, level_i)}))),'
#gates_per_level += "\n#level"+str(level_i)+"\nOr( p_allow_o"+str(output_i)+"_n"+str(gate)+"_lv"+str(level_i)+", And( p_id_n"+str(gate)+"_lv"+str(level_i)+"== " + self._connection_constraints(npl,level_i,gate) + "If(p_neg_n"+str(gate)+"_lv"+str(level_i)+", Not(p_id_n"+str(gate)+"_lv"+str(level_i)+"), p_id_n"+str(gate)+"_lv"+str(level_i)+"))),"
if level_i == len(npl) - 1:
gates_per_level += "),"
gates_per_level = []
for level_i in range(len(npl)):
for gate in range( npl[level_i]):
gates_per_level.append(f'\n#level: {level_i}\n{self._id_parameter(gate,level_i)}() == And({self._connection_constraints(npl, level_i, gate, output_i)}),\n{self._level_parameter(gate,level_i)}() == Or({self._allow_node_output(gate,level_i)}, If({self._neg_parameter(gate,level_i)}, Not({self._id_parameter(gate,level_i)}()), {self._id_parameter(gate,level_i)}())),')

return gates_per_level



def _update_builder(self, builder: Builder) -> None:
# apply superclass updates
super()._update_builder(builder)
Expand All @@ -813,10 +804,10 @@ def _update_builder(self, builder: Builder) -> None:
npl = [None]*self.LV

#initialization gpl
#TODO: reove the + 1
#TODO: remove the + 1
#Amedeo: note that this could be parametrized with different number of gates for each level
for i in range(len(npl)):
npl[i] = 2 #self._specs.pit
npl[i] = 2#self._specs.pit

npl[self.LV - 1] = len(self.subgraph_outputs)

Expand All @@ -835,45 +826,60 @@ def _update_builder(self, builder: Builder) -> None:
),
itertools.chain.from_iterable(
(
self._gen_declare_gate(self._level_parameter(nd,lv)), # n#_lv#
self._gen_declare_gate(self._neg_parameter(nd,lv)), # p_neg_n#_lv#
self._gen_declare_gate(self._id_parameter(nd,lv)) #p_id_n0_lv1
#self._gen_declare_gate(self._level_parameter(nd,lv)), # n#_lv#
#self._gen_declare_gate(self._id_parameter(nd,lv)) # p_id_n#_lv1
)
for lv in range(len(npl))
for nd in range(npl[lv])
),
itertools.chain.from_iterable(
(
self._gen_declare_gate(self._node_connection(f_nd,lv-1,t_nd,lv)), # p_con_fn#_lv#_tn#_lv#
self._gen_declare_gate(self._switch_parameter(f_nd,lv-1,t_nd,lv)),# p_sw_fn#_lv#_tn#_lv#
)
for lv in range(len(npl)-1,0,-1)
for t_nd in range(npl[lv])
for f_nd in range(npl[lv-1])
),
itertools.chain(
self._gen_declare_gate(self._allow_node_output(output_i,nd,lv)) # p_allow_o#_n#_lv#
for output_i in self.subgraph_outputs.keys()
self._gen_declare_gate(self._allow_node_output(nd,lv)) # p_allow_n#_lv#
for lv in range(len(npl))
for nd in range(npl[lv])
),
itertools.chain.from_iterable(
itertools.chain.from_iterable(
(
self._gen_declare_gate((pars := self._input_parameters(output_i,input_i,nd))[0]), # p_o#_i#_n#_l
self._gen_declare_gate(pars[1]) # p_o#_i#_n#_s
)
for output_i in self.subgraph_outputs.keys()
for input_i in self.subgraph_inputs.keys()
for nd in range(npl[0])
),
),
itertools.chain.from_iterable(
(
self._gen_declare_gate(self._node_connection(output_i,f_nd,lv-1,t_nd,lv)), # p_con_o#_fn#_lv#_tn#_lv#
)
for output_i in self.subgraph_outputs.keys()
for lv in range(len(npl)-1,0,-1)
for t_nd in range(npl[lv])
for f_nd in range(npl[lv-1])
),
itertools.chain.from_iterable(
(
self._gen_declare_bool_function(self._level_parameter(nd,lv),0), # function n#_lv#
self._gen_declare_bool_function(self._id_parameter(nd,lv),0), # function p_id_n#_lv1
)
for lv in range(len(npl))
for nd in range(npl[lv])
),
),
)
)

# approximate_wires_constraints
def get_preds(name: str) -> Collection[str]: return sorted(self._current_graph.graph.predecessors(name), key=lambda n: int(re.search(r'\d+', n).group()))
def get_func(name: str) -> str: return self._current_graph.graph.nodes[name][sxpat_cfg.LABEL]

lines = []
for gate_i, gate_name in self.current_gates.items():

Expand Down Expand Up @@ -902,11 +908,17 @@ def get_func(name: str) -> str: return self._current_graph.graph.nodes[name][sxp
f'{sxpat_cfg.APPROXIMATE_WIRE_PREFIX}{len(self.inputs) + gate_i}', #that's why we have plus 8 or something in a#
self.subgraph_inputs.values()
)

node = (f'{self._generate_levels(npl, output_i)}')

#TODO: refactor this line of code
lines.append(f'{output_use} == And(And({sxpat_cfg.PRODUCT_PREFIX}{output_i}== Or({node}),\n # output selection \n If(mult_o{output_i},Or(Not(p_o{output_i}_l == {sxpat_cfg.PRODUCT_PREFIX}{output_i}),p_o{output_i}_s),False)),')
lines.append('\n'.join(self._generate_levels(npl, output_i)))
#TODO: constraint at least one connection to the output
lines.append(f'\n {sxpat_cfg.PRODUCT_PREFIX}{output_i} == Or(' + ',\n'.join(
itertools.chain(
f'{self._level_parameter(gate,len(npl)-1)}()'
for gate in range(npl[len(npl)-1])
)
)+ '),')
lines.append(f'{output_use} == Or(Not(p_o{output_i}_l == {sxpat_cfg.PRODUCT_PREFIX}{output_i}),p_o{output_i}_s),')

builder.update(approximate_wires_constraints='\n'.join(lines))

Expand All @@ -920,10 +932,10 @@ def get_func(name: str) -> str: return self._current_graph.graph.nodes[name][sxp
for nd in range(npl[0])
),
itertools.chain(
f'Implies({", ".join(self._output_mul_parameter(output_i))})'
f'Implies({", ".join(self._output_mul_parameter(output_i))}),'
for output_i in self.subgraph_outputs.keys()
)
))+',')
)))

# p_con_fn#_lv#_tn#_lv#
# total_wpg = []
Expand Down Expand Up @@ -957,10 +969,10 @@ def get_func(name: str) -> str: return self._current_graph.graph.nodes[name][sxp
)
)
for node_i in range(npl[0])
)if lv == 0 or out_i > 0 else tuple(
)if lv == 0 else tuple(
self._encoding.aggregate_variables(
itertools.chain(
self._node_connection(node_fr,lv-1,node_to,lv)
self._node_connection(out_i,node_fr,lv-1,node_to,lv)
for node_to in range(npl[lv])
)
)
Expand All @@ -976,30 +988,20 @@ def get_func(name: str) -> str: return self._current_graph.graph.nodes[name][sxp
# connection_constraint
builder.update(connection_constraint= '\n'.join(
itertools.chain(
f'Implies({self._node_connection(nd,lv,nd_,lv+1)},{self._allow_node_output(o_i,nd,lv)}),'
f'Implies({self._node_connection(o_i,nd,lv,nd_,lv+1)},Not({self._allow_node_output(nd,lv)})),'
for o_i in self.subgraph_outputs.keys()
for lv in range (len(npl)-1)
for nd in range(npl[lv])
for nd_ in range(npl[lv+1])
)
))

#last level constraint

#####################################################################################################################################################################################################################
# JUST FOR RUN THE TEST

# remove_zero_permutations_constraint
lines = []
""" for output_i in self.subgraph_outputs.keys():
parameters = (
self._product_parameter(output_i, product_i)
for product_i in range(self._specs.ppo)
)
lines.append(f'Implies(Not({self._output_parameter(output_i)}), Not(Or({", ".join(parameters)}))),') """
builder.update(remove_zero_permutations_constraint='\n'.join(lines))

######################################################################################################################################################################################################################

# general informations: benchmark_name, encoding and cell
builder.update(
benchmark_name=self._specs.benchmark_name,
Expand Down

0 comments on commit b739d12

Please sign in to comment.