Skip to content

Commit

Permalink
Make numpy shorter (#83)
Browse files Browse the repository at this point in the history
- replacing `np.dot` with `@`
- replacing the `numpy.something` calls with `from numpy import something` if functions appear several times
- replacing the `numpy.something` call with `np.something` if a function appears rarely
  • Loading branch information
mcocdawc authored Jan 13, 2025
1 parent 7ca8d14 commit 0eefb8c
Show file tree
Hide file tree
Showing 25 changed files with 504 additions and 512 deletions.
4 changes: 2 additions & 2 deletions example/kbe_polyacetylene.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# A supercell with 4 carbon & 4 hydrogen atoms is defined as unit cell in
# pyscf's periodic HF calculation

import numpy
import numpy as np
from pyscf.pbc import df, gto, scf

from quemb.kbe import BE, fragpart
Expand All @@ -14,7 +14,7 @@
b = 8.0
c = 2.455 * 2.0

lat = numpy.eye(3)
lat = np.eye(3)
lat[0, 0] = a
lat[1, 1] = b
lat[2, 2] = c
Expand Down
4 changes: 2 additions & 2 deletions example/molbe_dmrg_block2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
# Garnet-Chan group at Caltech: https://block2.readthedocs.io/en/latest/index.html

import matplotlib.pyplot as plt
import numpy
import numpy as np
from pyscf import cc, fci, gto, scf

from quemb.molbe import BE, fragpart
from quemb.molbe.solver import DMRG_ArgsUser

# We'll consider the dissociation curve for a 1D chain of 8 H-atoms:
num_points = 3
seps = numpy.linspace(0.60, 1.6, num=num_points)
seps = np.linspace(0.60, 1.6, num=num_points)
fci_ecorr, ccsd_ecorr, ccsdt_ecorr, bedmrg_ecorr = [], [], [], []

for a in seps:
Expand Down
52 changes: 25 additions & 27 deletions src/quemb/kbe/autofrag.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Author(s): Oinam Romesh Meitei


import numpy
from numpy import arange, asarray, where
from numpy.linalg import norm
from pyscf import lib

Expand Down Expand Up @@ -83,10 +83,10 @@ def sidefunc(
rlist=[],
):
if ext_list == []:
main_list.extend(unit2[numpy.where(unit1 == Idx)[0]])
sub_list.extend(unit2[numpy.where(unit1 == Idx)[0]])
main_list.extend(unit2[where(unit1 == Idx)[0]])
sub_list.extend(unit2[where(unit1 == Idx)[0]])
else:
for sub_i in unit2[numpy.where(unit1 == Idx)[0]]:
for sub_i in unit2[where(unit1 == Idx)[0]]:
if sub_i in rlist:
continue
if sub_i in ext_list:
Expand All @@ -104,7 +104,7 @@ def sidefunc(
close_be3 = []

if be_type == "be3" or be_type == "be4":
for lmin1 in unit2[numpy.where(unit1 == Idx)[0]]:
for lmin1 in unit2[where(unit1 == Idx)[0]]:
for jdx, j in enumerate(coord):
if (
jdx not in unit1
Expand Down Expand Up @@ -396,11 +396,9 @@ def autogen(
lnk2 = 1

lattice_vector = cell.lattice_vectors()
Ts = lib.cartesian_prod(
(numpy.arange(lkpt[0]), numpy.arange(lkpt[1]), numpy.arange(lkpt[2]))
)
Ts = lib.cartesian_prod((arange(lkpt[0]), arange(lkpt[1]), arange(lkpt[2])))

Ls = numpy.dot(Ts, lattice_vector)
Ls = Ts @ lattice_vector

# 1-2-(1-2)-1-2
# * *
Expand Down Expand Up @@ -469,12 +467,12 @@ def autogen(
kmsites = []
ktsites = []

lunit = numpy.asarray(lunit)
runit = numpy.asarray(runit)
uunit = numpy.asarray(uunit)
dunit = numpy.asarray(dunit)
munit = numpy.asarray(munit)
tunit = numpy.asarray(tunit)
lunit = asarray(lunit)
runit = asarray(runit)
uunit = asarray(uunit)
dunit = asarray(dunit)
munit = asarray(munit)
tunit = asarray(tunit)

inter_dist = 1000.0
if twoD and interlayer:
Expand Down Expand Up @@ -1307,7 +1305,7 @@ def autogen(
continue
if be_type == "be3" or be_type == "be4":
if jdx in lunit:
lmin1 = runit[numpy.where(lunit == jdx)[0]]
lmin1 = runit[where(lunit == jdx)[0]]
if not twoD:
flist.extend(lmin1)
lsts.extend(lmin1)
Expand Down Expand Up @@ -1387,7 +1385,7 @@ def autogen(
bond=bond,
)
if jdx in runit:
rmin1 = lunit[numpy.where(runit == jdx)[0]]
rmin1 = lunit[where(runit == jdx)[0]]
if not twoD:
flist.extend(rmin1)
rsts.extend(rmin1)
Expand Down Expand Up @@ -1467,7 +1465,7 @@ def autogen(
)

if jdx in uunit:
umin1 = dunit[numpy.where(uunit == jdx)[0]]
umin1 = dunit[where(uunit == jdx)[0]]
add_check_k(umin1, flist, usts, kusts, 2)
if be_type == "be4":
for kdx, k in enumerate(coord):
Expand Down Expand Up @@ -1534,7 +1532,7 @@ def autogen(
bond=bond,
)
if jdx in dunit:
dmin1 = uunit[numpy.where(dunit == jdx)[0]]
dmin1 = uunit[where(dunit == jdx)[0]]
add_check_k(dmin1, flist, dsts, kdsts, nk1)

if be_type == "be4":
Expand Down Expand Up @@ -1602,7 +1600,7 @@ def autogen(
bond=bond,
)
if jdx in munit: #
mmin1 = tunit[numpy.where(munit == jdx)[0]]
mmin1 = tunit[where(munit == jdx)[0]]
add_check_k(mmin1, flist, msts, kmsts, nk1 * nk2)
if be_type == "be4":
for kdx, k in enumerate(coord):
Expand Down Expand Up @@ -1656,7 +1654,7 @@ def autogen(
)

if jdx in tunit:
tmin1 = munit[numpy.where(tunit == jdx)[0]]
tmin1 = munit[where(tunit == jdx)[0]]
add_check_k(tmin1, flist, tsts, ktsts, nk1 + 2)

if be_type == "be4":
Expand Down Expand Up @@ -1724,7 +1722,7 @@ def autogen(
pedg.append(kdx)
if be_type == "be4":
if kdx in lunit:
lmin1 = runit[numpy.where(lunit == kdx)[0]]
lmin1 = runit[where(lunit == kdx)[0]]
for zdx in lmin1:
if (
zdx in lsts
Expand All @@ -1740,7 +1738,7 @@ def autogen(
else:
klsts.append(nk1)
if kdx in runit:
rmin1 = lunit[numpy.where(runit == kdx)[0]]
rmin1 = lunit[where(runit == kdx)[0]]
for zdx in rmin1:
if (
zdx in rsts
Expand All @@ -1755,7 +1753,7 @@ def autogen(
else:
krsts.append(2)
if kdx in uunit:
umin1 = dunit[numpy.where(uunit == kdx)[0]]
umin1 = dunit[where(uunit == kdx)[0]]
for zdx in umin1:
if (
zdx in usts
Expand All @@ -1767,7 +1765,7 @@ def autogen(
usts.append(zdx)
kusts.append(2)
if kdx in dunit:
dmin1 = uunit[numpy.where(dunit == kdx)[0]]
dmin1 = uunit[where(dunit == kdx)[0]]
for zdx in dmin1:
if (
zdx in dsts
Expand All @@ -1779,7 +1777,7 @@ def autogen(
dsts.append(zdx)
kdsts.append(nk1)
if kdx in munit:
mmin1 = tunit[numpy.where(munit == kdx)[0]]
mmin1 = tunit[where(munit == kdx)[0]]
for zdx in mmin1:
if (
zdx in msts
Expand All @@ -1791,7 +1789,7 @@ def autogen(
msts.append(zdx)
kmsts.append(nk1 * nk2)
if kdx in tunit:
tmin1 = munit[numpy.where(tunit == kdx)[0]]
tmin1 = munit[where(tunit == kdx)[0]]
for zdx in tmin1:
if (
zdx in tsts
Expand Down
12 changes: 6 additions & 6 deletions src/quemb/kbe/helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Author(s): Oinam Romesh Meitei

import numpy
from numpy import asarray, complex128, float64, zeros
from numpy.linalg import multi_dot
from pyscf import scf

Expand Down Expand Up @@ -33,21 +33,21 @@ def get_veff(eri_, dm, S, TA, hf_veff, return_veff0=False):
# construct rdm
nk, nao, neo = TA.shape
unused(nao)
P_ = numpy.zeros((neo, neo), dtype=numpy.complex128)
P_ = zeros((neo, neo), dtype=complex128)
for k in range(nk):
Cinv = numpy.dot(TA[k].conj().T, S[k])
Cinv = TA[k].conj().T @ S[k]
P_ += multi_dot((Cinv, dm[k], Cinv.conj().T))
P_ /= float(nk)

P_ = numpy.asarray(P_.real, dtype=numpy.double)
P_ = asarray(P_.real, dtype=float64)

eri_ = numpy.asarray(eri_, dtype=numpy.double)
eri_ = asarray(eri_, dtype=float64)
vj, vk = scf.hf.dot_eri_dm(eri_, P_, hermi=1, with_j=True, with_k=True)
Veff_ = vj - 0.5 * vk

# remove core contribution from hf_veff

Veff0 = numpy.zeros((neo, neo), dtype=numpy.complex128)
Veff0 = zeros((neo, neo), dtype=complex128)
for k in range(nk):
Veff0 += multi_dot((TA[k].conj().T, hf_veff[k], TA[k]))
Veff0 /= float(nk)
Expand Down
Loading

0 comments on commit 0eefb8c

Please sign in to comment.