forked from tobacco-mofs/tobacco_3.0
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtobacco.py
400 lines (330 loc) · 14 KB
/
tobacco.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
from reindex import apply_reindex
from ciftemplate2graph import ct2g
from vertex_edge_assign import vertex_assign, assign_node_vecs2edges
from cycle_cocyle import cycle_cocyle, Bstar_alpha
from bbcif_properties import cncalc, bbelems
from SBU_geometry import SBU_coords
from scale import scale
from scaled_embedding2coords import omega2coords
from place_bbs import scaled_node_and_edge_vectors, place_nodes, place_edges
from remove_net_charge import fix_charges
from remove_dummy_atoms import remove_Fr
from adjust_edges import adjust_edges
from write_cifs import write_check_cif, write_cif, bond_connected_components, distance_search_bond, fix_bond_sym
from scale_animation import scaling_callback_animation, write_scaling_callback_animation, animate_objective_minimization
import configuration
import glob
import os
import re
import networkx as nx
import numpy as np
import sys
import itertools
import time
import traceback
# Flip to True to make debug_pause() stop in the interactive debugger.
DEBUG = False


def debug_pause():
    """Enter a pdb session at the call site when DEBUG is set; otherwise do nothing."""
    if not DEBUG:
        return
    # Imported lazily so pdb is only loaded when debugging is requested.
    import pdb
    pdb.set_trace()
# Wall-clock reference used by the run-time report printed at the end of the script.
start_time = time.time()

####### Global options #######
# Every name below is a one-to-one copy of the attribute with the same name
# on the `configuration` module (listed alphabetically).
ALL_NODE_COMBINATIONS                 = configuration.ALL_NODE_COMBINATIONS
BOND_TOL                              = configuration.BOND_TOL
CHARGES                               = configuration.CHARGES
COMBINATORIAL_EDGE_ASSIGNMENT         = configuration.COMBINATORIAL_EDGE_ASSIGNMENT
CONNECTION_SITE_BOND_LENGTH           = configuration.CONNECTION_SITE_BOND_LENGTH
EXPANSIVE_BOND_SEARCH                 = configuration.EXPANSIVE_BOND_SEARCH
FIX_UC                                = configuration.FIX_UC
NODE_TO_NODE                          = configuration.NODE_TO_NODE
ONE_ATOM_NODE_CN                      = configuration.ONE_ATOM_NODE_CN
ORIENTATION_DEPENDENT_NODES           = configuration.ORIENTATION_DEPENDENT_NODES
OUTPUT_SCALING_DATA                   = configuration.OUTPUT_SCALING_DATA
PLACE_EDGES_BETWEEN_CONNECTION_POINTS = configuration.PLACE_EDGES_BETWEEN_CONNECTION_POINTS
PRE_SCALE                             = configuration.PRE_SCALE
PRINT                                 = configuration.PRINT
RECORD_CALLBACK                       = configuration.RECORD_CALLBACK
SCALING_CONVERGENCE_TOLERANCE         = configuration.SCALING_CONVERGENCE_TOLERANCE
SCALING_ITERATIONS                    = configuration.SCALING_ITERATIONS
SCALING_STEP_SIZE                     = configuration.SCALING_STEP_SIZE
SINGLE_ATOM_NODE                      = configuration.SINGLE_ATOM_NODE
SINGLE_METAL_MOFS_ONLY                = configuration.SINGLE_METAL_MOFS_ONLY
SYMMETRY_TOL                          = configuration.SYMMETRY_TOL
TRACE_BOND_MAKING                     = configuration.TRACE_BOND_MAKING
USER_SPECIFIED_NODE_ASSIGNMENT        = configuration.USER_SPECIFIED_NODE_ASSIGNMENT
WRITE_CHECK_FILES                     = configuration.WRITE_CHECK_FILES
WRITE_CIF                             = configuration.WRITE_CIF
####### Global options #######
# Shorthand used by the degree -> radian conversions in the main loop.
pi = np.pi

# Maps the element symbol of a template vertex name (digits stripped) to the
# integer used to build its 'v<N>' label in the main loop.
# The values must be unique: 'Sc' previously shared 24 with 'Ca' (25 was the
# only number missing from the 1..30 sequence), which made the two vertex
# types indistinguishable in the generated labels.
vname_dict = {
    'V': 1,   'Er': 2,  'Ti': 3,  'Ce': 4,  'S': 5,
    'H': 6,   'He': 7,  'Li': 8,  'Be': 9,  'B': 10,
    'C': 11,  'N': 12,  'O': 13,  'F': 14,  'Ne': 15,
    'Na': 16, 'Mg': 17, 'Al': 18, 'Si': 19, 'P': 20,
    'Cl': 21, 'Ar': 22, 'K': 23,  'Ca': 24, 'Sc': 25,
    'Cr': 26, 'Mn': 27, 'Fe': 28, 'Co': 29, 'Ni': 30,
}

# Element symbols counted as metals when the main loop applies the
# SINGLE_METAL_MOFS_ONLY filter to a vertex assignment.
metal_elements = ['Cu', 'Cr', 'Zn', 'Zr', 'Fe', 'Al']
#apply_reindex(CHARGES)
# The input directories are enumerated with os.listdir() further down, so a
# stray '.DS_Store' entry would be picked up as if it were an input file.
# Remove it up front on a best-effort basis.
for d in ['templates', 'nodes', 'edges']:
    try:
        os.remove(os.path.join(d, '.DS_Store'))
    except OSError:
        # File (or directory) absent or not removable: nothing to clean up.
        # Only OSError is ignored so Ctrl-C and programming errors still surface.
        pass
# Run-wide bookkeeping. cif_wrapper_id is incremented by write_cif_wrapper();
# err_count is incremented by the main loop's exception handler.
cif_wrapper_list = []
cif_wrapper_id = err_count = 0

# When scaling output is requested, open the log and write its header row.
# The handle stays open for the whole run and is closed after the main loop.
if OUTPUT_SCALING_DATA:
    scd = open('scaling_data.txt', 'w')
    scd.write('ab_ratio_i ab_ratio_f ac_ratio_i ac_ratio_f bc_ratio_i bc_ratio_f alpha_i alpha_f beta_i beta_f gamma_i gamma_f average_covec final_obj\n')
def write_cif_wrapper(placed_all, fixed_bonds, scaled_params, sc_unit_cell, cifname, charges):
    """Bump the module-level ``cif_wrapper_id`` counter, then call ``write_cif``.

    All six arguments are forwarded to ``write_cif`` unchanged and in the same
    order; nothing is returned.
    """
    global cif_wrapper_id
    cif_wrapper_id = cif_wrapper_id + 1
    write_cif(placed_all, fixed_bonds, scaled_params, sc_unit_cell, cifname, charges)
for template in sorted(os.listdir('templates')):
##try catch the errors and procede
try:
#{wrapper of try
topo_cif_id=0 #temporary
print ''
print '========================================================================================================='
print 'template :',template
print '========================================================================================================='
print ''
sys.stdout.flush()
TG, unit_cell, TVT, TET, TNAME, a, b, c, ang_alpha, ang_beta, ang_gamma, max_le = ct2g(template)
#ct2g: return(G, unit_cell, cns, e_types, cifname, a, b, c, alpha, beta, gamma, max_le)
#TG = G#ct2g# = nx.MultiGraph()
node_cns = [(cncalc(node, 'nodes', ONE_ATOM_NODE_CN), node) for node in os.listdir('nodes')]
edge_type_key = dict((list(TET)[k],k) for k in xrange(len(TET)))
print 'Number of vertices = ', len(TG.nodes())
print 'Number of edges = ', len(TG.edges())
print ''
if PRINT:
print 'There are', len(TG.nodes()), 'vertices in the voltage graph:'
print ''
q = 0
for node in TG.nodes():
q += 1
print q,':',node
node_dict = TG.node[node]
print 'type : ', node_dict['type']
print 'cartesian coords : ', node_dict['ccoords']
print 'fractional coords : ', node_dict['fcoords']
print 'degree : ', node_dict['cn'][0]
print ''
print 'There are', len(TG.edges()), 'edges in the voltage graph:'
print ''
for edge in TG.edges(data=True,keys=True):
edge_dict = edge[3]
ind = edge[2]
print ind,':',edge[0],edge[1]
print 'length : ',edge_dict['length']
print 'type : ',edge_dict['type']
print 'label : ',edge_dict['label']
print 'positive direction :',edge_dict['pd']
print 'cartesian coords : ',edge_dict['ccoords']
print 'fractional coords : ',edge_dict['fcoords']
print ''
vas = vertex_assign(TG, TVT, node_cns, unit_cell, ONE_ATOM_NODE_CN, USER_SPECIFIED_NODE_ASSIGNMENT, SYMMETRY_TOL, ALL_NODE_COMBINATIONS)
#vas=return value from vertex_assign = a list of tuple (vertex_topo,node_sbu)
num_edges = len(TG.edges())
CB,CO = cycle_cocyle(TG)
for va in vas:
if len(va) == 0:
print 'At least one vertex does not have a building block with the correct number of connection sites.'
print 'Moving to the next template.'
print ''
continue
if len(CB) != (len(TG.edges()) - len(TG.nodes()) + 1):
print 'The cycle basis is incorrect.'
print 'The number of cycles in the cycle basis does not equal the rank of the cycle space.'
print 'Moving to the next tempate.'
continue
Bstar, alpha = Bstar_alpha(CB,CO,TG,num_edges)
if PRINT:
print 'star (top) and alpha (bottom) for the barycentric embedding are: '
print ''
for i in Bstar:
print i
print ''
for i in alpha:
print i
print ''
omega_plus = np.dot(np.linalg.inv(Bstar),alpha)
num_vertices = len(TG.nodes())
if COMBINATORIAL_EDGE_ASSIGNMENT:
eas = list(itertools.product([e for e in os.listdir('edges')], repeat = len(TET)))
else:
edge_files = sorted([e for e in os.listdir('edges')])
eas = []
i = 0
while len(eas) < len(TET):
eas.append(edge_files[i])
i += 1
if i == len(edge_files):
i = 0
eas = [eas]
g = 0
for va in vas:
node_elems = [bbelems(i[1], 'nodes') for i in va]
metals = [[i for i in j if i in metal_elements] for j in node_elems]
metals = set([i for j in metals for i in j])
v_set = [('v' + str(vname_dict[re.sub('[0-9]','',i[0])]), i[1]) for i in va]
v_set = sorted(list(set(v_set)), key=lambda x: x[0])
v_set = [v[0] + '-' + v[1] for v in v_set]
print '++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++'
print 'vertex assignment : ',v_set
print '++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++'
print ''
if len(metals) != 1 and SINGLE_METAL_MOFS_ONLY:
print v_set, 'contains no metals or multiple metal elements, no cif will be written', metals
print ''
continue
for v in va:
for n in TG.nodes(data=True): #see Graph class - If True, return entire node attribute dict as (n, ddict).
if v[0] == n[0]:
n[1]['cifname'] = v[1]
for ea in eas:
g += 1
print '++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++'
print 'edge assignment : ',ea
print '++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++'
print ''
###### skip for faster computation
type_assign = dict((k,[]) for k in TET)
for k,m in zip(TET,ea):
type_assign[k] = m
for e in TG.edges(data=True):
ty = e[2]['type']
for k in type_assign:
if ty == k or (ty[1],ty[0]) == k:
e[2]['cifname'] = type_assign[k]
ea_dict = assign_node_vecs2edges(TG, unit_cell, SYMMETRY_TOL)
all_SBU_coords = SBU_coords(TG, ea_dict, CONNECTION_SITE_BOND_LENGTH)
sc_a, sc_b, sc_c, sc_alpha, sc_beta, sc_gamma, sc_covar, Bstar_inv, max_length, callbackresults, ncra, ncca, scaling_data = scale(all_SBU_coords,a,b,c,ang_alpha,ang_beta,ang_gamma,max_le,num_vertices,Bstar,alpha,num_edges,FIX_UC,SCALING_ITERATIONS,PRE_SCALE,SCALING_CONVERGENCE_TOLERANCE,SCALING_STEP_SIZE)
print '*******************************************'
print 'The scaled unit cell parameters are : '
print '*******************************************'
print 'a :', np.round(sc_a, 5)
print 'b :', np.round(sc_b, 5)
print 'c :', np.round(sc_c, 5)
print 'alpha:', np.round(sc_alpha, 5)
print 'beta :', np.round(sc_beta, 5)
print 'gamma:', np.round(sc_gamma, 5)
print ''
for sc, name in zip((sc_a, sc_b, sc_c), ('a', 'b', 'c')):
cflag = False
if sc < 1.0:
print 'unit cell parameter', name, 'has collapsed during scaling!'
print 'try re-running with', name, 'fixed, with a larger value for PRE_SCALE, or with a higher SCALING_CONVERGENCE_TOLERANCE'
print 'no cif will be written'
cflag = True
if cflag:
continue
scaled_params = [sc_a,sc_b,sc_c,sc_alpha,sc_beta,sc_gamma]
sc_Alpha = np.r_[alpha[0:num_edges-num_vertices+1,:], sc_covar]
sc_omega_plus = np.dot(Bstar_inv, sc_Alpha)
ax = sc_a
ay = 0.0
az = 0.0
bx = sc_b * np.cos(sc_gamma * pi/180.0)
by = sc_b * np.sin(sc_gamma * pi/180.0)
bz = 0.0
cx = sc_c * np.cos(sc_beta * pi/180.0)
cy = (sc_c * sc_b * np.cos(sc_alpha * pi/180.0) - bx * cx) / by
cz = (sc_c ** 2.0 - cx ** 2.0 - cy ** 2.0) ** 0.5
sc_unit_cell = np.asarray([[ax,ay,az],[bx,by,bz],[cx,cy,cz]]).T
scaled_coords = omega2coords(TG, sc_omega_plus, (sc_a,sc_b,sc_c,sc_alpha,sc_beta,sc_gamma), num_vertices, template, g, WRITE_CHECK_FILES)
nvecs,evecs = scaled_node_and_edge_vectors(scaled_coords, sc_omega_plus, sc_unit_cell, ea_dict)
placed_nodes, node_bonds = place_nodes(nvecs, CHARGES, ORIENTATION_DEPENDENT_NODES)
placed_edges, edge_bonds = place_edges(evecs, CHARGES, len(placed_nodes))
if RECORD_CALLBACK:
vnames = '_'.join([v.split('.')[0] for v in v_set])
if len(ea) <= 5:
enames = '_'.join([e[0:-4] for e in ea])
else:
enames = str(len(ea)) + '_edges'
prefix = template[0:-4] + '_' + vnames + '_' + enames
frames = scaling_callback_animation(callbackresults, alpha, Bstar_inv, ncra, ncca, num_vertices, num_edges, TG, template, g, False)
write_scaling_callback_animation(frames, prefix)
animate_objective_minimization(callbackresults, prefix)
if PLACE_EDGES_BETWEEN_CONNECTION_POINTS:
placed_edges = adjust_edges(placed_edges, placed_nodes, sc_unit_cell)
placed_all = placed_nodes + placed_edges
bonds_all = node_bonds + edge_bonds
if WRITE_CHECK_FILES:
write_check_cif(template, placed_nodes, placed_edges, g, scaled_params, sc_unit_cell)
if SINGLE_ATOM_NODE or NODE_TO_NODE:
placed_all,bonds_all = remove_Fr(placed_all,bonds_all)
print 'computing X-X bonds...'
print ''
print '*******************************************'
print 'Bond formation : '
print '*******************************************'
sys.stdout.flush()
#check time cost
fixed_bonds, nbcount, bond_check = bond_connected_components(placed_all, bonds_all, sc_unit_cell, max_length, BOND_TOL, TRACE_BOND_MAKING, NODE_TO_NODE, EXPANSIVE_BOND_SEARCH, ONE_ATOM_NODE_CN)
print 'there were ' + str(nbcount) + ' X-X bonds formed'
if bond_check:
print 'bond check passed'
bond_check_code = ''
else:
continue
print 'bond check failed, attempting distance search bonding...'
#check time cost
fixed_bonds, nbcount = distance_search_bond(placed_all, bonds_all, sc_unit_cell, 2.5, TRACE_BOND_MAKING)
bond_check_code = '_BOND_CHECK'
print 'there were', nbcount, 'X-X bonds formed'
print ''
if CHARGES:
fc_placed_all, netcharge, onetcharge, rcb = fix_charges(placed_all)
else:
fc_placed_all = placed_all
#check time cost
fixed_bonds = fix_bond_sym(fixed_bonds, placed_all, sc_unit_cell)
if CHARGES:
print '*******************************************'
print 'Charge information : '
print '*******************************************'
print 'old net charge :', np.round(onetcharge, 5)
print 'new net charge (after rescaling):', np.round(netcharge, 5)
print 'rescaling magnitude :', np.round(rcb, 5)
print ''
vnames = '_'.join([v.split('.')[0] for v in v_set])
if len(ea) <= 5:
enames = []
for e in [e[0:-4] for e in ea]:
if e not in enames:
enames.append(e)
enames = '_'.join(enames)
else:
enames = str(len(ea)) + '_edges'
#altered for shorter cifname
#cifname = template[0:-4] + '_' + vnames + '_' + enames + bond_check_code + '.cif'
topo_cif_id+=1
cifname=template[0:-4] + '_' + str(topo_cif_id) + '.cif'
if OUTPUT_SCALING_DATA:
line = [cifname.split('.')[0]] + [i for j in scaling_data for i in j]
format_string = ' '.join(['{}' for i in line])
scd.write(format_string.format(*line))
scd.write('\n')
if WRITE_CIF:
print 'writing cif...'
print ''
sys.stdout.flush()
debug_pause()
#write out cif file with indexed table
write_cif_wrapper(fc_placed_all, fixed_bonds, scaled_params, sc_unit_cell, cifname, CHARGES)
#write_cif(fc_placed_all, fixed_bonds, scaled_params, sc_unit_cell, cifname, CHARGES)
#}wrapper of try--catch
except Exception as e:
err_count+=1
etype, value, tb = sys.exc_info()
print("-----an error runtime----")
traceback.print_exc()
if OUTPUT_SCALING_DATA:
scd.close()
print 'Normal termination of Tobacco_3.0 after'
print '--- %s seconds ---' % (time.time() - start_time)
print 'total errors: %s' % err_count
print ''