Skip to content

Commit

Permalink
Loss analysis viz, debug tbplas api and support compute eigenvalues n…
Browse files Browse the repository at this point in the history
…k point each time (#111)

* abacus parse to support upto l=6, get band structure with batched kpoint

* update tbplas and loss viz

* debug to tbplas
  • Loading branch information
floatingCatty authored Apr 8, 2024
1 parent 7ff210a commit a226cfd
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 27 deletions.
2 changes: 1 addition & 1 deletion dptb/data/interfaces/abacus.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def find_target_line(f, target):
with open(os.path.join(output_path, "basis.dat"), 'w') as f:
for atomic_number in element:
counter = Counter(orbital_types_dict[atomic_number])
num_orbs = [counter[i] for i in range(4)] # s, p, d, f
num_orbs = [counter[i] for i in range(6)] # s, p, d, f, g, h
for index_l, l in enumerate(num_orbs):
if l == 0: # no this orbit
continue
Expand Down
30 changes: 19 additions & 11 deletions dptb/nn/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,24 @@ def __init__(
self.s_out_field = s_out_field


def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
data = self.h2k(data)
if self.overlap:
data = self.s2k(data)
chklowt = torch.linalg.cholesky(data[self.s_out_field])
chklowtinv = torch.linalg.inv(chklowt)
Heff = (chklowtinv @ data[self.h_out_field] @ torch.transpose(chklowtinv,dim0=1,dim1=2).conj())
else:
Heff = data[self.h_out_field]

data[self.out_field] = torch.linalg.eigvalsh(Heff)
def forward(self, data: AtomicDataDict.Type, nk: Optional[int]=None) -> AtomicDataDict.Type:
num_k = data[AtomicDataDict.KPOINT_KEY].shape[0]
kpoints = data[AtomicDataDict.KPOINT_KEY]
eigvals = []
if nk is None:
nk = num_k
for i in range(int(np.ceil(num_k / nk))):
data[AtomicDataDict.KPOINT_KEY] = kpoints[i*nk:(i+1)*nk]
data = self.h2k(data)
if self.overlap:
data = self.s2k(data)
chklowt = torch.linalg.cholesky(data[self.s_out_field])
chklowtinv = torch.linalg.inv(chklowt)
data[self.h_out_field] = (chklowtinv @ data[self.h_out_field] @ torch.transpose(chklowtinv,dim0=1,dim1=2).conj())
else:
data[self.h_out_field] = data[self.h_out_field]

eigvals.append(torch.linalg.eigvalsh(data[self.h_out_field]))
data[self.out_field] = torch.cat(eigvals, dim=0)

return data
107 changes: 107 additions & 0 deletions dptb/nnops/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from e3nn.o3 import Irreps
from torch_scatter import scatter_mean
from dptb.utils.torch_geometric import Batch
import matplotlib.pyplot as plt
from dptb.utils.constants import anglrMId
import re

"""this is the register class for descriptors
Expand Down Expand Up @@ -603,6 +606,110 @@ def __call__(self, data: AtomicDataDict, ref_data: AtomicDataDict, running_avg:
self.stats["rmse"] = self.stats["rmse"].sqrt()

return self.stats

def visualize(self):
assert hasattr(self, "stats"), "The stats is not computed yet."

with torch.no_grad():
print("Onsite:")
for at, tp in self.idp.chemical_symbol_to_type.items():
print(f"{at}:")
print(f"MAE: {self.stats['onsite'][at]['mae']}")
print(f"RMSE: {self.stats['onsite'][at]['rmse']}")

# compute the onsite per block err
onsite_mae = torch.zeros((self.idp.full_basis_norb, self.idp.full_basis_norb,), dtype=self.dtype, device=self.device)
onsite_rmse = torch.zeros((self.idp.full_basis_norb, self.idp.full_basis_norb,), dtype=self.dtype, device=self.device)
ist = 0
for i,iorb in enumerate(self.idp.full_basis):
jst = 0
li = anglrMId[re.findall(r"[a-zA-Z]+", iorb)[0]]
for j,jorb in enumerate(self.idp.full_basis):
orbpair = iorb + "-" + jorb
lj = anglrMId[re.findall(r"[a-zA-Z]+", jorb)[0]]

# constructing hopping blocks
if iorb == jorb:
factor = 0.5
else:
factor = 1.0

# constructing onsite blocks
if i <= j:
onsite_mae[ist:ist+2*li+1,jst:jst+2*lj+1] = factor * self.stats["onsite"][at]["mae_per_block_element"][self.idp.orbpair_maps[orbpair]].reshape(2*li+1, 2*lj+1)
onsite_rmse[ist:ist+2*li+1,jst:jst+2*lj+1] = factor * self.stats["onsite"][at]["rmse_per_block_element"][self.idp.orbpair_maps[orbpair]].reshape(2*li+1, 2*lj+1)

jst += 2*lj+1
ist += 2*li+1

onsite_mae += onsite_mae.clone().T
onsite_rmse += onsite_rmse.clone().T

imask = self.idp.mask_to_basis[tp]
onsite_mae = onsite_mae[imask][:,imask]
onsite_rmse = onsite_rmse[imask][:,imask]

plt.matshow(onsite_mae.detach().cpu().numpy(), cmap="Blues", vmin=0, vmax=1e-3)
plt.title("MAE")
plt.colorbar()
plt.show()

plt.matshow(onsite_rmse.detach().cpu().numpy(), cmap="Blues", vmin=0, vmax=1e-3)
plt.title("RMSE")
plt.colorbar()
plt.show()

# compute the hopping per block err
print("Hopping:")
for bt, tp in self.idp.bond_to_type.items():
print(f"{bt}:")
print(f"MAE: {self.stats['hopping'][bt]['mae']}")
print(f"RMSE: {self.stats['hopping'][bt]['rmse']}")
hopping_mae = torch.zeros((self.idp.full_basis_norb, self.idp.full_basis_norb,), dtype=self.dtype, device=self.device)
hopping_rmse = torch.zeros((self.idp.full_basis_norb, self.idp.full_basis_norb,), dtype=self.dtype, device=self.device)
ist = 0
for i,iorb in enumerate(self.idp.full_basis):
jst = 0
li = anglrMId[re.findall(r"[a-zA-Z]+", iorb)[0]]
for j,jorb in enumerate(self.idp.full_basis):
orbpair = iorb + "-" + jorb
lj = anglrMId[re.findall(r"[a-zA-Z]+", jorb)[0]]

# constructing hopping blocks
if iorb == jorb:
factor = 0.5
else:
factor = 1.0

# constructing onsite blocks
if i <= j:
hopping_mae[ist:ist+2*li+1,jst:jst+2*lj+1] = factor * self.stats["hopping"][bt]["mae_per_block_element"][self.idp.orbpair_maps[orbpair]].reshape(2*li+1, 2*lj+1)
hopping_rmse[ist:ist+2*li+1,jst:jst+2*lj+1] = factor * self.stats["hopping"][bt]["rmse_per_block_element"][self.idp.orbpair_maps[orbpair]].reshape(2*li+1, 2*lj+1)

jst += 2*lj+1
ist += 2*li+1

hopping_mae += hopping_mae.clone().T
hopping_rmse += hopping_rmse.clone().T

imask = self.idp.mask_to_basis[tp]
jmask = self.idp.mask_to_basis[tp]
hopping_mae = hopping_mae[imask][:,jmask]
hopping_rmse = hopping_rmse[imask][:,jmask]

plt.matshow(hopping_mae.detach().cpu().numpy(), cmap="Blues", vmin=0, vmax=1e-3)
plt.title("MAE")
plt.colorbar()
plt.show()

plt.matshow(hopping_mae.detach().cpu().numpy(), cmap="Blues", vmin=0, vmax=1e-3)
plt.title("RMSE")
plt.colorbar()
plt.show()





def __cal_norm__(self, irreps: Irreps, x: torch.Tensor):
id = 0
Expand Down
40 changes: 25 additions & 15 deletions dptb/postprocess/totbplas.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,15 @@ def get_cell(self, data: Union[AtomicData, ase.Atoms, str], AtomicData_options:
data = AtomicData.to_AtomicDataDict(data.to(self.device))
data = self.model.idp(data)

if self.overlap == True:
assert data.get(AtomicDataDict.EDGE_OVERLAP_KEY) is not None

# get the HR
data = self.model(data)

if self.overlap == True:
assert data.get(AtomicDataDict.EDGE_OVERLAP_KEY) is not None

cell = data[AtomicDataDict.CELL_KEY]
cell_inv = cell.inverse()
tbplus_cell = tb.PrimitiveCell(lat_vec=cell.cpu(), unit=tb.ANG)
tbplas_cell = tb.PrimitiveCell(lat_vec=cell.cpu(), unit=tb.ANG)

orbs = {}
norbs = {}
Expand Down Expand Up @@ -108,7 +108,7 @@ def get_cell(self, data: Union[AtomicData, ase.Atoms, str], AtomicData_options:

energy = onsite_blocks[self.model.idp.orbpair_maps[forb+"-"+forb]].reshape(2*l+1, 2*l+1)[m+l, m+l].item()

tbplus_cell.add_orbital(
tbplas_cell.add_orbital(
(cell_inv @ data[AtomicDataDict.POSITIONS_KEY][i]).cpu(),
energy=energy - e_fermi,
label=orbs[isymbol][io]
Expand Down Expand Up @@ -138,7 +138,7 @@ def get_cell(self, data: Union[AtomicData, ase.Atoms, str], AtomicData_options:
idx = orbsidict[str(i)+"-"+orbs[isymbol][xo]]
idy = orbsidict[str(i)+"-"+orbs[isymbol][yo]]
if abs(energy) > 1e-7 and not idx==idy:
tbplus_cell._hopping_dict.add_hopping(rn=(0, 0, 0,), orb_i=idx, orb_j=idy, energy=energy)
tbplas_cell._hopping_dict.add_hopping(rn=(0, 0, 0,), orb_i=idx, orb_j=idy, energy=energy)
# off-diagonal part

for i, iindex in enumerate(data[AtomicDataDict.EDGE_TYPE_KEY].flatten()):
Expand Down Expand Up @@ -166,11 +166,21 @@ def get_cell(self, data: Union[AtomicData, ase.Atoms, str], AtomicData_options:
energy = hopping_blocks[self.model.idp.orbpair_maps[forbx+"-"+forby]].reshape(2*lx+1, 2*ly+1)[mx+lx, my+ly].item()
idx = orbsidict[str(indx)+"-"+orbs[isymbol][xo]]
idy = orbsidict[str(jndx)+"-"+orbs[jsymbol][yo]]
if abs(energy) > 1e-7 and not idx==idy:
if abs(energy) > 1e-7:
rn = data[AtomicDataDict.EDGE_CELL_SHIFT_KEY][i].cpu().numpy()
tbplus_cell._hopping_dict.add_hopping(rn=(rn[0], rn[1], rn[2]), orb_i=idx, orb_j=idy, energy=energy)

return tbplus_cell
rev = tbplas_cell.hoppings.get((-rn[0], -rn[1], -rn[2]))
if rev is not None:
rev = rev.get((idy,idx))
if rev is not None:
# in case of the hopping is not symmetric
energy = (energy + rev) / 2
tbplas_cell._hopping_dict.add_hopping(rn=(rn[0], rn[1], rn[2]), orb_i=idx, orb_j=idy, energy=energy)
tbplas_cell._hopping_dict.add_hopping(rn=(-rn[0], -rn[1], -rn[2]), orb_i=idy, orb_j=idx, energy=energy)
else:
tbplas_cell._hopping_dict.add_hopping(rn=(rn[0], rn[1], rn[2]), orb_i=idx, orb_j=idy, energy=energy)


return tbplas_cell



Expand Down Expand Up @@ -207,7 +217,7 @@ def write(self):
factor = 13.605662285137

lat = self.structase.cell
tbplus_cell = tb.PrimitiveCell(lat_vec=lat, unit=tb.ANG)
tbplas_cell = tb.PrimitiveCell(lat_vec=lat, unit=tb.ANG)

if os.path.exists(os.path.join(self.results_path, "HR.pth")):
f = torch.load(os.path.join(self.results_path, "HR.pth"))
Expand Down Expand Up @@ -260,7 +270,7 @@ def write(self):
orbsidict[str(i)+"-"+orbs[label][io]] = orbcount # e.g.: [1-s,1-py ...]
orbcount += 1

tbplus_cell.add_orbital(self.structase[i].scaled_position,
tbplas_cell.add_orbital(self.structase[i].scaled_position,
energy=onsite_blocks[io,io].item(), label=orbs[label][io])
# accum_norbs = np.cumsum(accum_norbs)
# off-diagonal part
Expand All @@ -279,13 +289,13 @@ def write(self):
idy = orbsidict[str(j)+"-"+orbs[jlabel][yo]]
if abs(energy) > 1e-7 and not \
((R_bonds[ix] < 1e-14) & (idx==idy)):
tbplus_cell._hopping_dict.add_hopping(rn=(Rx, Ry, Rz), orb_i=idx, orb_j=idy, energy=energy)
tbplas_cell._hopping_dict.add_hopping(rn=(Rx, Ry, Rz), orb_i=idx, orb_j=idy, energy=energy)

if self.jdata["cal_fermi"]:

nele = self.jdata["nele"]

super_cell = tb.SuperCell(tbplus_cell, dim=self.jdata["supercell"], pbc=self.jdata["pbc"])
super_cell = tb.SuperCell(tbplas_cell, dim=self.jdata["supercell"], pbc=self.jdata["pbc"])
sample = tb.Sample(super_cell)
sample.rescale_ham()

Expand Down Expand Up @@ -313,4 +323,4 @@ def write(self):
else:
e_fermi = self.jdata.get("e_fermi", 0)

return tbplus_cell, e_fermi
return tbplas_cell, e_fermi

0 comments on commit a226cfd

Please sign in to comment.