Source code for gyrointerp.helpers

"""
Helper functions for other parts of gyro-interp.  Useful contents:

    | ``get_summary_statistics``
    | ``sample_ages_from_pdf``
"""
#############
## LOGGING ##
#############
import logging
from gyrointerp import log_sub, log_fmt, log_date_fmt

# Flip DEBUG to True to get debug-level log output from this module.
DEBUG = False
level = logging.DEBUG if DEBUG else logging.INFO

LOGGER = logging.getLogger(__name__)
logging.basicConfig(
    level=level, style=log_sub, format=log_fmt, datefmt=log_date_fmt
)

# Module-level shorthands for the bound logger methods.
LOGDEBUG = LOGGER.debug
LOGINFO = LOGGER.info
LOGWARNING = LOGGER.warning
LOGERROR = LOGGER.error
LOGEXCEPTION = LOGGER.exception

#############
## IMPORTS ##
#############
import warnings
import os
from os.path import join
import numpy as np, pandas as pd
from scipy.interpolate import interp1d
from scipy.integrate import quad, IntegrationWarning

from scipy import __version__ as scipyversion
from packaging import version
scipy_ver = version.parse(scipyversion)

# Per https://docs.scipy.org/doc/scipy/release/1.12.0-notes.html
if scipy_ver >= version.parse("1.12.0"):
    from scipy.integrate import cumulative_trapezoid as integration_func
else:
    from scipy.integrate import cumtrapz as integration_func

warnings.filterwarnings(
    "ignore", category=IntegrationWarning
)

def get_summary_statistics(age_grid, age_post):
    """
    Summarize an age posterior defined on a grid of ages.

    The posterior probability density ``age_post``, sampled on ``age_grid``,
    is interpolated to obtain its median, mean, peak, and the +/-1, 2, and
    3-sigma intervals.

    Args:

        age_grid (np.ndarray):
            Array-like of ages, in units of megayears.  For instance, the
            default *age_grid* in ``gyro_posterior.gyro_age_posterior`` is
            ``np.linspace(0, 3000, 500)``.

        age_post (np.ndarray):
            Posterior probability distribution for ages; its length should
            match *age_grid*.  The posteriors returned by
            ``gyro_posterior.gyro_age_posterior`` and
            ``gyro_posterior.gyro_age_posterior_list`` are examples that
            would work, but any grid and probability distribution can be
            passed.

    Returns:

        dict : summary_statistics

            Dictionary with the keys median, mean, peak (mode), +/-1sigma,
            +/-2sigma, +/-3sigma, and +/-1sigmapct.  All values are in
            megayears except *+/-1sigmapct*, which is the +/-1-sigma
            uncertainty divided by the posterior median (dimensionless).
    """
    # All of the work is delegated to the private implementation.
    summary_statistics = _given_grid_post_get_summary_statistics(
        age_grid, age_post
    )
    return summary_statistics
def sample_ages_from_pdf(age_grid, age_post, n_samples=1000, n_grid=1000000):
    """
    Draw samples from a given age posterior probability density function
    (PDF) using linear interpolation.

    The PDF is linearly interpolated onto a fine, evenly spaced grid spanning
    ``age_grid[0]`` to ``age_grid[-1]``, and samples are drawn from that fine
    grid with probability proportional to the interpolated PDF.  Linear
    (rather than higher-order) interpolation is used so that the interpolated
    PDF cannot go negative between grid points.

    Sampling uses NumPy's global random state (``np.random.choice``); call
    ``np.random.seed`` beforehand for reproducible draws.

    Args:

        age_grid (np.ndarray):
            Array of ages (in megayears) representing the grid points of the
            PDF.

        age_post (np.ndarray):
            Array of posterior probabilities corresponding to the age grid
            points.  It does not need to be normalized.

        n_samples (int, optional):
            Number of samples to draw from the PDF. Default is 1000.

        n_grid (int, optional):
            Number of points in the fine sampling grid.  Drawn ages are
            quantized at this resolution.  For an age_grid spanning 0 to 5000
            Myr, the default of 1,000,000 implies a resolution of 0.005 Myr,
            which is small enough that the implied truncation error
            downstream is negligible.

    Returns:

        np.ndarray : age_samples

            Numpy array of shape (n_samples,) containing the drawn age
            samples from the PDF.

    Raises:

        ValueError: if the integral of ``age_post`` over ``age_grid`` is zero
            or not finite, so that the posterior cannot be normalized.
    """
    age_grid = np.asarray(age_grid, dtype=float)
    age_post = np.asarray(age_post, dtype=float)

    # ``np.trapz`` is deprecated in NumPy 2 in favor of ``np.trapezoid``;
    # use whichever the installed NumPy provides.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    # Normalize the posterior probability (PDF)
    norm = trapezoid(age_post, age_grid)
    if not np.isfinite(norm) or norm == 0:
        raise ValueError(
            "Cannot sample ages: the integral of age_post over age_grid is "
            f"{norm}, so the posterior cannot be normalized."
        )
    age_pdf = age_post / norm

    # Linear interpolation of the PDF; avoids negative values.
    pdf_interp = interp1d(age_grid, age_pdf, kind='linear')

    # Evaluate the interpolated PDF on the fine grid of candidate ages.
    age_fine_grid = np.linspace(age_grid[0], age_grid[-1], n_grid)
    pdf_fine = pdf_interp(age_fine_grid)

    # Normalize the interpolated PDF
    pdf_fine /= trapezoid(pdf_fine, age_fine_grid)

    # np.random.choice needs discrete weights that sum to one.
    weights = pdf_fine / pdf_fine.sum()
    age_samples = np.random.choice(age_fine_grid, size=n_samples, p=weights)

    return age_samples
def _deprecated_given_grid_post_get_summary_statistics(age_grid, age_post, N=int(1e5)): """ yields results consistent within 10% of the interpolation-based _given_grid_post_get_summary_statistics implementation below; deprecated however because this approach yields quantization at the grid level. """ age_peak = int(age_grid[np.argmax(age_post)]) df = pd.DataFrame({'age':age_grid, 'p':age_post}) try: sample_df = df.sample(n=N, replace=True, weights=df.p) except ValueError: outdict = { 'median': np.nan, 'peak': np.nan, 'mean': np.nan, '+1sigma': np.nan, '-1sigma': np.nan, '+2sigma': np.nan, '-2sigma': np.nan, '+3sigma': np.nan, '-3sigma': np.nan, '+1sigmapct': np.nan, '-1sigmapct': np.nan, } return outdict one_sig = 68.27/2 two_sig = 95.45/2 three_sig = 99.73/2 pct_50 = np.nanpercentile(sample_df.age, 50) p1sig = np.nanpercentile(sample_df.age, 50+one_sig) - pct_50 m1sig = pct_50 - np.nanpercentile(sample_df.age, 50-one_sig) p2sig = np.nanpercentile(sample_df.age, 50+two_sig) - pct_50 m2sig = pct_50 - np.nanpercentile(sample_df.age, 50-two_sig) p3sig = np.nanpercentile(sample_df.age, 50+three_sig) - pct_50 m3sig = pct_50 - np.nanpercentile(sample_df.age, 50-three_sig) outdict = { 'median': np.round(pct_50,2), 'peak': np.round(age_peak,2), 'mean': np.round(np.nanmean(sample_df.age),2), '+1sigma': np.round(p1sig,2), '-1sigma': np.round(m1sig,2), '+2sigma': np.round(p2sig,2), '-2sigma': np.round(m2sig,2), '+3sigma': np.round(p3sig,2), '-3sigma': np.round(m3sig,2), '+1sigmapct': np.round(p1sig/pct_50,2), '-1sigmapct': np.round(m1sig/pct_50,2), } return outdict def _given_grid_post_get_summary_statistics(age_grid, age_post): if not np.all(np.isfinite(age_post)): outdict = { 'median': np.nan, 'peak': np.nan, 'mean': np.nan, '+1sigma': np.nan, '-1sigma': np.nan, '+2sigma': np.nan, '-2sigma': np.nan, '+3sigma': np.nan, '-3sigma': np.nan, '+1sigmapct': np.nan, '-1sigmapct': np.nan, } return outdict age_peak = int(age_grid[np.argmax(age_post)]) # Normalize the posterior 
probability (PDF) age_pdf = age_post / np.trapz(age_post, age_grid) # Calculate the cumulative distribution function (CDF) age_cdf = integration_func(age_pdf, age_grid, initial=0) # Create interpolation functions for PDF and CDF pdf_interp = interp1d(age_grid, age_pdf, kind='linear') cdf_interp = interp1d(age_grid, age_cdf, kind='linear') # Calculate the median age median_age = _find_percentile(cdf_interp, age_grid, 0.5) # Calculate the mean age mean_age = _calculate_mean(pdf_interp, age_grid) # Calculate the +/- 1, 2, and 3 sigma intervals p1sig, m1sig = _find_sigma_interval(cdf_interp, age_grid, 0.6827) p2sig, m2sig = _find_sigma_interval(cdf_interp, age_grid, 0.9545) p3sig, m3sig = _find_sigma_interval(cdf_interp, age_grid, 0.9973) outdict = { 'median': np.round(median_age, 2), 'peak': np.round(age_peak, 2), 'mean': np.round(mean_age, 2), '+1sigma': np.round(p1sig, 2), '-1sigma': np.round(m1sig, 2), '+2sigma': np.round(p2sig, 2), '-2sigma': np.round(m2sig, 2), '+3sigma': np.round(p3sig, 2), '-3sigma': np.round(m3sig, 2), '+1sigmapct': np.round(p1sig / median_age, 2), '-1sigmapct': np.round(m1sig / median_age, 2), } return outdict def _find_percentile(cdf_interp, age_grid, percentile): def objective(x): return cdf_interp(x) - percentile return _find_root(objective, age_grid[0], age_grid[-1]) def _calculate_mean(pdf_interp, age_grid): def integrand(x): return x * pdf_interp(x) mean, _ = quad(integrand, age_grid[0], age_grid[-1]) return mean def _find_sigma_interval(cdf_interp, age_grid, sigma_fraction): median = _find_percentile(cdf_interp, age_grid, 0.5) p_sigma = _find_percentile(cdf_interp, age_grid, 0.5 + sigma_fraction / 2) m_sigma = _find_percentile(cdf_interp, age_grid, 0.5 - sigma_fraction / 2) return p_sigma - median, median - m_sigma def _find_root(func, a, b, tol=1e-6): fa, fb = func(a), func(b) assert fa * fb <= 0, "Root not bracketed" while abs(b - a) > tol: c = (a + b) / 2 fc = func(c) if fc == 0: return c elif fa * fc < 0: b, fb = c, fc else: a, fa 
= c, fc return (a + b) / 2
def prepend_colstr(colstr, df):
    """
    Return a copy of ``df`` in which every column name has the string
    ``colstr`` prepended to it.  ``df`` itself is left unchanged.
    """
    renamed = {}
    for column in df.columns:
        renamed[column] = colstr + column
    return df.rename(columns=renamed)
def left_merge(df0, df1, col0, col1):
    """
    Left-join ``df1`` onto ``df0``, comparing the key columns as strings.

    The key columns are cast to ``str`` before merging so that, for example,
    integer and string identifiers can be matched.  The cast is applied to
    copies: the DataFrames passed in are not modified.

    Args:
        df0 (pd.DataFrame): Left table; all of its rows are kept.
        df1 (pd.DataFrame): Right table.
        col0: Name of the key column in ``df0``.
        col1: Name of the key column in ``df1``.

    Returns:
        pd.DataFrame: Result of ``merge(..., how='left')``, in which the key
        columns hold strings.
    """
    # Copy first: assigning the cast column onto the arguments would change
    # the caller's DataFrames as a hidden side effect.
    left = df0.copy()
    right = df1.copy()
    left[col0] = left[col0].astype(str)
    right[col1] = right[col1].astype(str)

    return left.merge(right, left_on=col0, right_on=col1, how='left')
def given_dr2_get_dr3_dataframes(dr2_source_ids, runid_dr2, runid_dr3,
                                 overwrite=False):
    """
    Crossmatch Gaia DR2 source identifiers to DR3 and fetch the DR3 data.

    Requires the ``cdips`` package (``pip install cdips``), which provides
    the two query helpers used here.

    Args:
        dr2_source_ids: np.ndarray of np.int64 Gaia DR2 source identifiers.
        runid_dr2: arbitrary string to identify the DR2->DR3 xmatch query.
        runid_dr3: arbitrary (different) string to identify the DR3 query.
        overwrite (bool): forwarded to both ``cdips`` query helpers.

    Returns:
        tuple: ``(gdf, s_dr3)``.  ``gdf`` is the DR3 query result restricted
        to a fixed set of columns, with ``source_id`` renamed to
        ``dr3_source_id``.  ``s_dr3`` is the crossmatch table with one row
        per DR2 source and the columns ``dr2_source_id``, ``dr3_source_id``,
        ``angular_distance`` and ``magnitude_difference``.

    Raises:
        AssertionError: if the number of crossmatched rows differs from the
            number of unique DR2 source identifiers.
    """
    # pip install cdips
    from cdips.utils.gaiaqueries import (
        given_dr2_sourceids_get_edr3_xmatch, given_source_ids_get_gaia_data
    )

    LOGINFO(42*'-')
    LOGINFO(runid_dr2)

    # Crossmatch from Gaia DR2->DR3.
    dr2_x_dr3_df = given_dr2_sourceids_get_edr3_xmatch(
        dr2_source_ids, runid_dr2, overwrite=overwrite,
        enforce_all_sourceids_viable=True
    )

    # Take the closest magnitude difference as the single match.
    #
    # In NGC-3532 case, yields matches for everything, largest angular distance
    # 1.3 arcseconds, largest magnitude difference G=0.06 mags.
    #
    # For Pleiades, trickier, since the sample goes fainter.  Lack of proper
    # motion projection also leads to many erroneous cases.
    s_dr3 = (
        dr2_x_dr3_df
        .sort_values(by='abs_magnitude_difference')
        .drop_duplicates(subset='dr2_source_id', keep='first')
    )

    LOGINFO(10*'-')
    LOGINFO(s_dr3.describe())
    LOGINFO(10*'-')

    n_matched = len(s_dr3)
    n_unique = len(np.unique(dr2_source_ids))
    if n_matched != n_unique:
        LOGINFO('Got bad dr2<->dr3 match')
        # Format the counts into one message: passing them as separate
        # positional arguments makes logging try ``str(int) % (int,)``, which
        # fails inside the logging machinery.
        LOGINFO(
            f'{n_matched} crossmatched rows vs '
            f'{n_unique} unique DR2 source ids'
        )
        raise AssertionError(
            f'Got bad dr2<->dr3 match: {n_matched} crossmatched rows vs '
            f'{n_unique} unique DR2 source ids'
        )

    dr3_source_ids = np.array(s_dr3.dr3_source_id).astype(np.int64)

    gdf = given_source_ids_get_gaia_data(
        dr3_source_ids, runid_dr3, n_max=10000, overwrite=overwrite,
        enforce_all_sourceids_viable=True, which_columns='*',
        gaia_datarelease='gaiadr3'
    )
    gdf = gdf.rename({"source_id":"dr3_source_id"}, axis='columns')

    selcols = ['dr3_source_id', 'ra', 'dec', 'parallax', 'parallax_error',
               'parallax_over_error', 'pmra', 'pmdec', 'ruwe',
               'phot_g_mean_flux_over_error', 'phot_g_mean_mag',
               'phot_bp_mean_flux_over_error', 'phot_rp_mean_flux',
               'phot_rp_mean_flux_over_error', 'phot_bp_rp_excess_factor',
               'phot_bp_n_contaminated_transits', 'phot_bp_n_blended_transits',
               'phot_rp_n_contaminated_transits', 'phot_rp_n_blended_transits',
               'bp_rp', 'bp_g', 'g_rp', 'radial_velocity',
               'radial_velocity_error', 'rv_method_used',
               'rv_expected_sig_to_noise', 'vbroad', 'vbroad_error', 'l', 'b',
               'ecl_lon', 'ecl_lat', 'non_single_star', 'teff_gspphot',
               'teff_gspphot_lower', 'teff_gspphot_upper', 'azero_gspphot',
               'ag_gspphot', 'ebpminrp_gspphot']
    gdf = gdf[selcols]

    selcols = ['dr2_source_id', 'dr3_source_id', 'angular_distance',
               'magnitude_difference']
    s_dr3 = s_dr3[selcols]

    return gdf, s_dr3
def get_population_hyperparameter_posterior_samples():
    """
    Return the posterior samples described in section 3.5 of BPH23.  (These
    are generated by ``drivers.run_emcee_fit_gyro_model``).

    The samples are read from a CSV file in the package cache directory.  If
    that file does not exist yet, the samples are first downloaded from an
    external link and written to the cache.

    The returned numpy array holds samples in the following parameters:
    a1/a0, y_g, logk0, logk1, log_f.  The notation follows Sections 3.3-3.5
    of BPH23.
    """
    from gyrointerp.paths import CACHEDIR

    csvpath = join(CACHEDIR, "fit_120-Myr_300-Myr_Praesepe.csv.gz")

    # Fast path: the samples are already cached locally.
    if os.path.exists(csvpath):
        return np.array(pd.read_csv(csvpath))

    # Otherwise pull the population-level hyperparameters from the external
    # cache and store them locally for next time.
    dropboxlink = (
        'https://www.dropbox.com/s/ywe3z8ez2ll871m/fit_120-Myr_300-Myr_Praesepe.csv?dl=1'
    )
    downloaded = pd.read_csv(dropboxlink)
    downloaded.to_csv(csvpath, index=False)
    LOGINFO(f"Downloaded {csvpath} and cached it locally.")

    return np.array(downloaded)