Skip to content

Commit

Permalink
fix docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
segasai committed Jan 9, 2025
1 parent 6fe0cb9 commit 511a0c6
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
22 changes: 14 additions & 8 deletions py/gpu_specter/extract/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ def evalcoeffs(psfdata, wavelengths, specmin=0, nspec=None):
Args:
psfdata: PSF data from io.read_psf() of Gauss Hermite PSF file
wavelengths: 1D array of wavelengths
wavelengths: 1D or 2D array of wavelengths. if 2d the shape should be
(nspec0, nwave) where nspec0 is the number of
fibers in psfdata
Options:
specmin: first spectrum to include
Expand Down Expand Up @@ -63,19 +65,20 @@ def evalcoeffs(psfdata, wavelengths, specmin=0, nspec=None):
ww = ww[specmin:specmin+nspec]
L = np.polynomial.legendre.legvander(ww, meta['LEGDEG'])
nwave = wavelengths.shape[-1]
nghx = meta['GHDEGX']+1
nghy = meta['GHDEGY']+1
nghx = meta['GHDEGX'] + 1
nghy = meta['GHDEGY'] + 1
p['GH'] = np.zeros((nghx, nghy, nspec, nwave))
for name, coeff in zip(psfdata['PSF']['PARAM'], psfdata['PSF']['COEFF']):
name = name.strip()
coeff = coeff[specmin:specmin+nspec]
if wave2d:
curv = np.einsum('kji,ki->kj', L, coeff) # L.dot(coeff.T).T
curv = np.einsum('kji,ki->kj', L, coeff)
else:
curv = np.einsum('ji,ki->kj', L, coeff) # L.dot(coeff.T).T
curv = np.einsum('ji,ki->kj', L, coeff)
# L.dot(coeff.T).T
if name.startswith('GH-'):
i, j = map(int, name.split('-')[1:3])
p['GH'][i,j] = curv
p['GH'][i, j] = curv
else:
p[name] = curv

Expand All @@ -92,7 +95,8 @@ def calc_pgh(ispec, wavelengths, psfparams):
Args:
ispec : integer spectrum number
wavelengths : array of wavelengths to evaluate
wavelengths : array of wavelengths to evaluate
either 1d (nwave,) or 2d (nspec,nwave)
psfparams : dictionary of PSF parameters returned by evalcoeffs
returns pGHx, pGHy
Expand Down Expand Up @@ -199,7 +203,9 @@ def get_spots(specmin, nspec, wavelengths, psfdata):
Args:
specmin: first spectrum to include
nspec: number of spectra to evaluate spots for
wavelengths: 1D array of wavelengths
wavelengths: 1D or 2D array of wavelengths. if 2D the wavelength shape
needs to be (nspec0, nwave) where nspec0 is the number
of spectra in psfdata
psfdata: PSF data from io.read_psf() of Gauss Hermite PSF file
Returns:
Expand Down
16 changes: 11 additions & 5 deletions py/gpu_specter/extract/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ def evalcoeffs(psfdata, wavelengths, specmin=0, nspec=None):
Args:
psfdata: PSF data from io.read_psf() of Gauss Hermite PSF file
wavelengths: 1D array of wavelengths
wavelengths: 1D or 2D array of wavelengths. if 2d the shape should be
(nspec0, nwave) where nspec0 is the number of
fibers in psfdata
Options:
specmin: first spectrum to include
Expand Down Expand Up @@ -104,9 +106,10 @@ def evalcoeffs(psfdata, wavelengths, specmin=0, nspec=None):
name = name.strip()
coeff = coeff[specmin:specmin+nspec]
if wave2d:
curv = cp.einsum('kji,ki->kj', L, coeff) # L.dot(coeff.T).T
curv = cp.einsum('kji,ki->kj', L, coeff)
else:
curv = cp.einsum('ji,ki->kj', L, coeff) # L.dot(coeff.T).T
curv = cp.einsum('ji,ki->kj', L, coeff)
# L.dot(coeff.T).T

if name.startswith('GH-'):
i, j = map(int, name.split('-')[1:3])
Expand All @@ -126,7 +129,8 @@ def calc_pgh(ispec, wavelengths, psfparams):
Calculate the pixelated Gauss Hermite for all wavelengths of a single spectrum
ispec : integer spectrum number
wavelengths : array of wavelengths to evaluate
wavelengths : array of wavelengths to evaluate
either 1d (nwave,) or 2d (nspec,nwave)
psfparams : dictionary of PSF parameters returned by evalcoeffs
returns pGHx, pGHy
Expand Down Expand Up @@ -238,7 +242,9 @@ def get_spots(specmin, nspec, wavelengths, psfdata):
Args:
specmin: first spectrum to include
nspec: number of spectra to evaluate spots for
wavelengths: 1D array of wavelengths
wavelengths: 1D or 2D array of wavelengths. if 2D the wavelength shape
needs to be (nspec0, nwave) where nspec0 is the number
of spectra in psfdata
psfdata: PSF data from io.read_psf() of Gauss Hermite PSF file
Returns:
Expand Down

0 comments on commit 511a0c6

Please sign in to comment.