Skip to content

Commit

Permalink
Merge pull request #151 from yucongalicechen/utils-update
Browse files Browse the repository at this point in the history
fix: update functions related to `diffpy.utils` update
  • Loading branch information
sbillinge authored Jan 14, 2025
2 parents a9fbbe5 + 402677b commit 08032ab
Show file tree
Hide file tree
Showing 10 changed files with 88 additions and 188 deletions.
23 changes: 23 additions & 0 deletions news/utils-updates.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
**Added:**

* <news item>

**Changed:**

* Functions that use DiffractionObject` in `diffpy.utils` to follow the new API.

**Deprecated:**

* <news item>

**Removed:**

* <news item>

**Fixed:**

* <news item>

**Security:**

* <news item>
50 changes: 26 additions & 24 deletions src/diffpy/labpdfproc/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import pandas as pd
from scipy.interpolate import interp1d

from diffpy.utils.scattering_objects.diffraction_objects import XQUANTITIES, Diffraction_object
from diffpy.utils.diffraction_objects import XQUANTITIES, DiffractionObject

RADIUS_MM = 1
N_POINTS_ON_DIAMETER = 300
TTH_GRID = np.arange(1, 180.1, 0.1)
# Round down the last element if it's slightly above 180 due to floating point precision
TTH_GRID[-1] = 180.00
CVE_METHODS = ["brute_force", "polynomial_interpolation"]

# pre-computed datasets for polynomial interpolation (fast calculation)
Expand Down Expand Up @@ -191,14 +193,14 @@ def _cve_brute_force(diffraction_data, mud):
muls = np.array(muls) / abs_correction.total_points_in_grid
cve = 1 / muls

cve_do = Diffraction_object(wavelength=diffraction_data.wavelength)
cve_do.insert_scattering_quantity(
TTH_GRID,
cve,
"tth",
metadata=diffraction_data.metadata,
name=f"absorption correction, cve, for {diffraction_data.name}",
cve_do = DiffractionObject(
xarray=TTH_GRID,
yarray=cve,
xtype="tth",
wavelength=diffraction_data.wavelength,
scat_quantity="cve",
name=f"absorption correction, cve, for {diffraction_data.name}",
metadata=diffraction_data.metadata,
)
return cve_do

Expand All @@ -211,22 +213,22 @@ def _cve_polynomial_interpolation(diffraction_data, mud):
if mud > 6 or mud < 0.5:
raise ValueError(
f"mu*D is out of the acceptable range (0.5 to 6) for polynomial interpolation. "
f"Please rerun with a value within this range or specifying another method from {* CVE_METHODS, }."
f"Please rerun with a value within this range or specifying another method from {*CVE_METHODS, }."
)
coeff_a, coeff_b, coeff_c, coeff_d, coeff_e = [
interpolation_function(mud) for interpolation_function in INTERPOLATION_FUNCTIONS
]
muls = np.array(coeff_a * MULS**4 + coeff_b * MULS**3 + coeff_c * MULS**2 + coeff_d * MULS + coeff_e)
cve = 1 / muls

cve_do = Diffraction_object(wavelength=diffraction_data.wavelength)
cve_do.insert_scattering_quantity(
TTH_GRID,
cve,
"tth",
metadata=diffraction_data.metadata,
name=f"absorption correction, cve, for {diffraction_data.name}",
cve_do = DiffractionObject(
xarray=TTH_GRID,
yarray=cve,
xtype="tth",
wavelength=diffraction_data.wavelength,
scat_quantity="cve",
name=f"absorption correction, cve, for {diffraction_data.name}",
metadata=diffraction_data.metadata,
)
return cve_do

Expand Down Expand Up @@ -257,7 +259,7 @@ def compute_cve(diffraction_data, mud, method="polynomial_interpolation", xtype=
xtype str
the quantity on the independent variable axis, allowed values are {*XQUANTITIES, }
method str
the method used to calculate cve, must be one of {* CVE_METHODS, }
the method used to calculate cve, must be one of {*CVE_METHODS, }
Returns
-------
Expand All @@ -270,14 +272,14 @@ def compute_cve(diffraction_data, mud, method="polynomial_interpolation", xtype=
global_xtype = cve_do_on_global_grid.on_xtype(xtype)[0]
cve_on_global_xtype = cve_do_on_global_grid.on_xtype(xtype)[1]
newcve = np.interp(orig_grid, global_xtype, cve_on_global_xtype)
cve_do = Diffraction_object(wavelength=diffraction_data.wavelength)
cve_do.insert_scattering_quantity(
orig_grid,
newcve,
xtype,
metadata=diffraction_data.metadata,
name=f"absorption correction, cve, for {diffraction_data.name}",
cve_do = DiffractionObject(
xarray=orig_grid,
yarray=newcve,
xtype=xtype,
wavelength=diffraction_data.wavelength,
scat_quantity="cve",
name=f"absorption correction, cve, for {diffraction_data.name}",
metadata=diffraction_data.metadata,
)
return cve_do

Expand Down
12 changes: 6 additions & 6 deletions src/diffpy/labpdfproc/labpdfprocapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

from diffpy.labpdfproc.functions import CVE_METHODS, apply_corr, compute_cve
from diffpy.labpdfproc.tools import known_sources, load_metadata, preprocessing_args
from diffpy.utils.diffraction_objects import XQUANTITIES, DiffractionObject
from diffpy.utils.parsers.loaddata import loadData
from diffpy.utils.scattering_objects.diffraction_objects import XQUANTITIES, Diffraction_object


def define_arguments():
Expand Down Expand Up @@ -170,12 +170,12 @@ def main():
f"exists. Please rerun specifying -f if you want to overwrite it."
)

input_pattern = Diffraction_object(wavelength=args.wavelength)
xarray, yarray = loadData(filepath, unpack=True)
input_pattern.insert_scattering_quantity(
xarray,
yarray,
args.xtype,
input_pattern = DiffractionObject(
xarray=xarray,
yarray=yarray,
xtype=args.xtype,
wavelength=args.wavelength,
scat_quantity="x-ray",
name=filepath.stem,
metadata=load_metadata(args, filepath),
Expand Down
109 changes: 0 additions & 109 deletions src/diffpy/labpdfproc/mud_calculator.py

This file was deleted.

23 changes: 13 additions & 10 deletions src/diffpy/labpdfproc/tools.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import copy
from pathlib import Path

from diffpy.labpdfproc.mud_calculator import compute_mud
from diffpy.utils.scattering_objects.diffraction_objects import QQUANTITIES, XQUANTITIES
from diffpy.utils.tools import get_package_info, get_user_info
from diffpy.utils.diffraction_objects import ANGLEQUANTITIES, QQUANTITIES, XQUANTITIES
from diffpy.utils.tools import check_and_build_global_config, compute_mud, get_package_info, get_user_info

WAVELENGTHS = {"Mo": 0.71, "Ag": 0.59, "Cu": 1.54}
WAVELENGTHS = {"Mo": 0.71073, "Ag": 0.59, "Cu": 1.5406}
known_sources = [key for key in WAVELENGTHS.keys()]

# Exclude wavelength from metadata to prevent duplication,
Expand Down Expand Up @@ -154,7 +153,9 @@ def set_xtype(args):
"""
if args.xtype.lower() not in XQUANTITIES:
raise ValueError(f"Unknown xtype: {args.xtype}. Allowed xtypes are {*XQUANTITIES, }.")
args.xtype = "q" if args.xtype.lower() in QQUANTITIES else "tth"
args.xtype = (
"q" if args.xtype.lower() in QQUANTITIES else "tth" if args.xtype.lower() in ANGLEQUANTITIES else "d"
)
return args


Expand Down Expand Up @@ -224,7 +225,8 @@ def load_user_metadata(args):

def load_user_info(args):
"""
Update username and email using get_user_info function from diffpy.utils
Load user info into args. If args are not provided, call check_and_build_global_config function from
diffpy.utils to prompt the user for inputs. Otherwise, call get_user_info with the provided arguments.
Parameters
----------
Expand All @@ -236,10 +238,11 @@ def load_user_info(args):
the updated argparse Namespace with username and email inserted
"""
config = {"username": args.username, "email": args.email}
config = get_user_info(config)
args.username = config["username"]
args.email = config["email"]
if args.username is None or args.email is None:
check_and_build_global_config()
config = get_user_info(owner_name=args.username, owner_email=args.email)
args.username = config.get("owner_name")
args.email = config.get("owner_email")
return args


Expand Down
6 changes: 5 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@ def user_filesystem(tmp_path):
f.write("good_data.xy \n")
f.write(f"{str(input_dir.resolve() / 'good_data.txt')}\n")

home_config_data = {"username": "home_username", "email": "[email protected]"}
home_config_data = {
"owner_name": "home_username",
"owner_email": "[email protected]",
"owner_orcid": "home_orcid",
}
with open(home_dir / "diffpyconfig.json", "w") as f:
json.dump(home_config_data, f)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from diffpy.utils.parsers import loadData
from diffpy.utils.parsers.loaddata import loadData


# Test that our readable and unreadable files are indeed readable and
Expand Down
23 changes: 11 additions & 12 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

from diffpy.labpdfproc.functions import CVE_METHODS, Gridded_circle, apply_corr, compute_cve
from diffpy.utils.scattering_objects.diffraction_objects import Diffraction_object
from diffpy.utils.diffraction_objects import DiffractionObject

params1 = [
([0.5, 3, 1], {(0.0, -0.5), (0.0, 0.0), (0.5, 0.0), (-0.5, 0.0), (0.0, 0.5)}),
Expand Down Expand Up @@ -59,11 +59,11 @@ def test_set_muls_at_angle(inputs, expected):


def _instantiate_test_do(xarray, yarray, xtype="tth", name="test", scat_quantity="x-ray"):
test_do = Diffraction_object(wavelength=1.54)
test_do.insert_scattering_quantity(
xarray,
yarray,
xtype,
test_do = DiffractionObject(
xarray=xarray,
yarray=yarray,
xtype=xtype,
wavelength=1.54,
scat_quantity=scat_quantity,
name=name,
metadata={"thing1": 1, "thing2": "thing2"},
Expand All @@ -81,14 +81,13 @@ def _instantiate_test_do(xarray, yarray, xtype="tth", name="test", scat_quantity
def test_compute_cve(inputs, expected, mocker):
xarray, yarray = np.array([90, 90.1, 90.2]), np.array([2, 2, 2])
expected_cve = np.array([0.5, 0.5, 0.5])
mocker.patch("diffpy.labpdfproc.functions.TTH_GRID", xarray)
mocker.patch("numpy.interp", return_value=expected_cve)
input_pattern = _instantiate_test_do(xarray, yarray)
actual_cve_do = compute_cve(input_pattern, mud=1, method="polynomial_interpolation", xtype=inputs[0])
expected_cve_do = _instantiate_test_do(
expected[0],
expected[1],
expected[2],
xarray=expected[0],
yarray=expected[1],
xtype=expected[2],
name="absorption correction, cve, for test",
scat_quantity="cve",
)
Expand Down Expand Up @@ -126,8 +125,8 @@ def test_apply_corr(mocker):
mocker.patch("numpy.interp", return_value=expected_cve)
input_pattern = _instantiate_test_do(xarray, yarray)
absorption_correction = _instantiate_test_do(
xarray,
expected_cve,
xarray=xarray,
yarray=expected_cve,
name="absorption correction, cve, for test",
scat_quantity="cve",
)
Expand Down
Loading

0 comments on commit 08032ab

Please sign in to comment.