# TODO: Add a module-level docstring summarizing the utilities defined in this file.
# TODO: Complete the docstrings of the functions that are still undocumented.
import logging
import logging.config
from math import ceil
from dataclasses import dataclass
from functools import wraps
from typing import Any, Callable, Annotated
import numpy as np
from numpy.typing import NDArray
from scipy.optimize import fsolve
from scipy.interpolate import interp1d
import astropy.constants as ct
import astropy.units as u
#import sys
#sys.path.append('..')
from bossa.constants import ZOH_SUN, Z_SUN, stellar_types
# Type alias: a value that is either a plain float or an astropy Quantity.
Quantity = float | u.quantity.Quantity
[docs]
@dataclass
class Length:
    """Dataclass holding a single integer ``length`` field.

    NOTE(review): presumably intended as metadata for ``typing.Annotated``
    hints (``Annotated`` is imported by this module) -- confirm against usage.
    """
    length: int
[docs]
def fix_unit(x, unit):
    """Return ``x`` as an astropy Quantity expressed in ``unit``.

    Parameters
    ----------
    x : number, array-like or astropy Quantity
        Value to normalize. If it is already a Quantity it is converted to
        ``unit`` with ``Quantity.to``; otherwise it is treated as a bare
        value and ``unit`` is attached to it.
    unit : astropy unit
        Target unit.

    Returns
    -------
    astropy Quantity
        ``x`` expressed in ``unit``.
    """
    if isinstance(x, u.quantity.Quantity):
        return x.to(unit)
    # Multiply out of place: ``x *= unit`` would request an in-place update
    # on mutable inputs such as numpy arrays instead of building a Quantity.
    return x * unit
[docs]
def FeH_to_OFe(FeH):
    """Map ``FeH`` to the corresponding ``OFe`` value with a piecewise relation.

    Parameters
    ----------
    FeH : float
        Input value (presumably the [Fe/H] abundance ratio -- confirm).

    Returns
    -------
    float
        ``0.5`` when ``FeH`` is below ``-1``; ``-0.5 * FeH`` otherwise.
    """
    return 0.5 if FeH < -1 else -0.5 * FeH
def _FeH_from_ZOH(FeH, ZOH):
    """Residual whose root in ``FeH`` corresponds to the given ``ZOH``.

    Evaluates ``FeH - ZOH + ZOH_SUN + FeH_to_OFe(FeH)``; the result is zero
    when ``FeH`` is consistent with ``ZOH``. Used as the target function of
    the root finder in ``ZOH_to_FeH``.
    """
    oxygen_to_iron = FeH_to_OFe(FeH)
    residual = FeH - ZOH + ZOH_SUN + oxygen_to_iron
    return residual
[docs]
def ZOH_from_FeH(FeH):
    """Compute ``ZOH`` from ``FeH`` as ``FeH + ZOH_SUN + FeH_to_OFe(FeH)``.

    Parameters
    ----------
    FeH : float
        Input value passed through the piecewise ``FeH_to_OFe`` relation.

    Returns
    -------
    float
        The corresponding ``ZOH`` value.
    """
    oxygen_to_iron = FeH_to_OFe(FeH)
    return FeH + ZOH_SUN + oxygen_to_iron
[docs]
def ZOH_to_FeH(ZOH):
    """Invert the FeH -> ZOH relation numerically.

    Finds the root in ``FeH`` of ``_FeH_from_ZOH(FeH, ZOH)`` with
    ``scipy.optimize.fsolve``, starting from an initial guess of ``-1``.

    Parameters
    ----------
    ZOH : float
        Value for which the matching ``FeH`` is sought.

    Returns
    -------
    float
        The single root returned by the solver.
    """
    initial_guess = np.array([-1])
    roots = fsolve(_FeH_from_ZOH, initial_guess, args=ZOH)
    # Exactly one root is expected; unpacking fails loudly otherwise.
    (FeH,) = roots
    return FeH
[docs]
def ZOH_to_FeH2(ZOH):
    """Convert ``ZOH`` to ``FeH`` by subtracting ``ZOH_SUN``.

    Unlike ``ZOH_to_FeH``, no ``OFe`` term is involved: the result is simply
    the offset of ``ZOH`` from ``ZOH_SUN``.
    """
    offset_from_solar = ZOH - ZOH_SUN
    return offset_from_solar
[docs]
def FeH_to_Z(FeH):
    """Convert ``FeH`` to ``Z`` as ``10**FeH * Z_SUN``.

    Parameters
    ----------
    FeH : float
        Logarithmic (base 10) scaling relative to ``Z_SUN``.

    Returns
    -------
    float
        ``Z_SUN`` scaled by ``10**FeH``.
    """
    solar_scaling = 10**FeH
    return solar_scaling * Z_SUN
[docs]
def interpolate(ipX, ipY, X):
    """Cubic-interpolate each row of a pair of arrays at the points ``X``.

    Parameters
    ----------
    ipX : numpy array
        2-dimensional array. Each row holds the x coordinates of one set of
        points between which to interpolate.
    ipY : numpy array
        2-dimensional array. Each row holds the y coordinates matching the
        same row of ``ipX``.
    X : numpy array
        1-dimensional array of x coordinates at which every row is evaluated.

    Returns
    -------
    Y : numpy array
        2-dimensional array with one row per row of ``ipX``/``ipY``; row ``k``
        contains the interpolation of ``(ipX[k], ipY[k])`` at each element of
        ``X``.
    """
    rows = [
        interp1d(x_nodes, y_nodes, kind='cubic')(X)
        for x_nodes, y_nodes in zip(ipX, ipY)
    ]
    return np.array(rows)
[docs]
def sample_histogram(sample, N=1, m_min=0.08):
    """Build the outline of a histogram in which every bin holds ``N`` points.

    Bin limits start at ``m_min``. After every ``N``-th element of ``sample``
    a new limit is placed halfway between that element and the next one; for
    the final element (no successor) the limit is extrapolated as
    ``1.5 * m - limits[-1] / 2``. Elements left over after the last complete
    group of ``N`` do not produce a bin. The height of each bin is
    ``N / (upper limit - lower limit)``.

    Parameters
    ----------
    sample : sequence of float
        Indexable sequence of values (presumably sorted in ascending order --
        confirm against callers).
    N : int, optional
        Number of sample elements per bin.
    m_min : float, optional
        Lower limit of the first bin.

    Returns
    -------
    CSI : numpy array
        Array of shape ``(2 * n_limits, 2)``. Column 0 holds every bin limit
        twice (sorted); column 1 holds the matching heights, starting and
        ending at 0 so the rows trace the step outline of the histogram.
    """
    limits = [m_min]
    in_current_bin = 0
    for i, m in enumerate(sample):
        in_current_bin += 1
        if in_current_bin != N:
            continue
        # Only the look-ahead may legitimately fail (end of the sample), so
        # it is the only statement guarded. Positional containers raise
        # IndexError; label-indexed ones (e.g. a pandas Series) raise KeyError.
        try:
            next_m = sample[i + 1]
        except (IndexError, KeyError):
            limit = (3 / 2) * m - limits[-1] / 2
        else:
            limit = m / 2 + next_m / 2
        limits.append(limit)
        in_current_bin = 0

    csis = [N / (upper - lower) for lower, upper in zip(limits[:-1], limits[1:])]

    # Duplicate every limit/height so consecutive rows draw vertical and
    # horizontal segments; pad the heights with 0 at both ends.
    CSI_X = np.sort(np.append(limits, limits))
    CSI_Y = np.concatenate(([0.0], np.repeat(csis, 2), [0.0]))
    return np.array([CSI_X, CSI_Y]).T
[docs]
def logp_to_a(logp, m_tot):
    """Convert log10 of an orbital period to a semi-major axis.

    Applies Kepler's third law, ``p = sqrt(4 pi^2 / (G M)) * a^(3/2)``,
    solved for ``a``.

    Parameters
    ----------
    logp : float
        Base-10 logarithm of the period; ``10**logp`` is taken in days.
    m_tot : float
        Total mass, multiplied by the solar mass constant (i.e. expressed in
        solar masses).

    Returns
    -------
    float
        Semi-major axis as a bare number in astronomical units.
    """
    period = (10**logp) * u.d
    kepler_factor = np.sqrt(4 * np.pi**2 / (ct.G.cgs * m_tot * ct.M_sun))
    semi_major_axis = (period / kepler_factor) ** (2 / 3)
    return semi_major_axis.to(u.au).value
[docs]
def a_to_logp(a, m_tot):
    """Convert a semi-major axis to log10 of the orbital period.

    Applies Kepler's third law, ``p = sqrt(4 pi^2 / (G M)) * a^(3/2)``.

    Parameters
    ----------
    a : float
        Semi-major axis; the astronomical unit is attached to it.
    m_tot : float
        Total mass, multiplied by the solar mass constant (i.e. expressed in
        solar masses).

    Returns
    -------
    float
        Base-10 logarithm of the period expressed in days.
    """
    separation_cm = (a * u.au).to(u.cm)
    kepler_factor = np.sqrt(4 * np.pi**2 / (ct.G.cgs * m_tot * ct.M_sun))
    period = kepler_factor * separation_cm ** (3 / 2)
    return np.log10(period.to(u.d).value)
[docs]
def symmetrize_masses(row):
    """Order the two masses of a row so that ``Mass(1) >= Mass(2)``.

    Parameters
    ----------
    row : mapping
        Object with ``'Mass(1)'`` and ``'Mass(2)'`` entries (e.g. a dataframe
        row). It is modified in place when a swap is needed.

    Returns
    -------
    mapping
        The same ``row`` object.
    """
    primary, secondary = row['Mass(1)'], row['Mass(2)']
    if primary < secondary:
        row['Mass(1)'], row['Mass(2)'] = secondary, primary
    return row
[docs]
def pull_snmass(row):
    """Fill the mass columns of a row from its ``SN``/``CP`` columns.

    When ``'Mass(1)'`` is zero and ``'Mass(SN)'`` is non-zero, ``'Mass(1)'``
    is overwritten with ``'Mass(SN)'`` and ``'Mass(2)'`` with ``'Mass(CP)'``.
    In every other case the row is left untouched.

    Parameters
    ----------
    row : mapping
        Row to update in place.

    Returns
    -------
    mapping
        The same ``row`` object.
    """
    # 'Mass(SN)' is only consulted when 'Mass(1)' is zero (short-circuit).
    if row['Mass(1)'] == 0 and row['Mass(SN)'] != 0:
        row['Mass(1)'] = row['Mass(SN)']
        row['Mass(2)'] = row['Mass(CP)']
    return row
[docs]
def pull_snmass1(row):
    """Return the first mass of a row, falling back to ``'Mass(SN)'``.

    Parameters
    ----------
    row : mapping
        Row with a ``'Mass(1)'`` entry and, when that entry is zero, a
        ``'Mass(SN)'`` entry.

    Returns
    -------
    number
        ``row['Mass(1)']`` when it is non-zero; otherwise ``row['Mass(SN)']``
        when that is non-zero; otherwise ``0``.
    """
    primary = row['Mass(1)']
    if not primary == 0:
        return primary
    sn_mass = row['Mass(SN)']
    return sn_mass if sn_mass != 0 else 0
[docs]
def pull_snmass2(row):
    """Return the second mass of a row, falling back to ``'Mass(CP)'``.

    Parameters
    ----------
    row : mapping
        Row with a ``'Mass(2)'`` entry and, when that entry is zero, a
        ``'Mass(CP)'`` entry.

    Returns
    -------
    number
        ``row['Mass(2)']`` when it is non-zero; otherwise ``row['Mass(CP)']``
        when that is non-zero; otherwise ``0``.
    """
    secondary = row['Mass(2)']
    if not secondary == 0:
        return secondary
    cp_mass = row['Mass(CP)']
    return cp_mass if cp_mass != 0 else 0
[docs]
def chirp_mass(row):
    """Compute the chirp mass ``(m1*m2)^(3/5) / (m1+m2)^(1/5)`` of a row.

    Parameters
    ----------
    row : mapping
        Row with ``'Mass_PostSN1'`` (m1) and ``'Mass_PostSN2'`` (m2) entries.

    Returns
    -------
    number
        The chirp mass, or ``0`` when m1 is zero.
    """
    m1, m2 = row['Mass_PostSN1'], row['Mass_PostSN2']
    if m1 == 0:
        return 0
    numerator = (m1 * m2) ** (3 / 5)
    denominator = (m1 + m2) ** (1 / 5)
    return numerator / denominator
[docs]
def bintype(bintype):
    """Translate a binary type from numeric IDs to a string abbreviation.

    Parameters
    ----------
    bintype : str
        Two component IDs joined by ``'+'``; each ID must be a key of
        ``stellar_types``.

    Returns
    -------
    The ``stellar_types`` entries of the two components, concatenated in
    order.
    """
    first_id, second_id = bintype.split('+')
    first_label = stellar_types[first_id]
    second_label = stellar_types[second_id]
    return first_label + second_label
[docs]
def mass_ratio(m1, m2):
    """Calculate the mass ratio q as lesser mass / greater mass.

    Parameters
    ----------
    m1, m2 : number
        The two masses, in any order.

    Returns
    -------
    number
        The smaller mass divided by the larger one, or ``0`` when the larger
        mass is zero.
    """
    greater, lesser = (m2, m1) if m1 < m2 else (m1, m2)
    if greater == 0:
        return 0
    return lesser / greater
[docs]
def create_logger(name=None, fpath=None, propagate=True, parent=None):
    """Create (or fetch) a logger and attach its default handlers.

    The logger level is set to DEBUG. A logger created without a ``parent``
    receives a stream handler at INFO level; any logger receives a file
    handler at DEBUG level when ``fpath`` is given. Calling this function
    again for the same logger does not attach the same handlers a second
    time, so records are not emitted in duplicate.

    Parameters
    ----------
    name : str, optional
        Logger name, passed to ``logging.getLogger`` or, when ``parent`` is
        given, to ``parent.getChild``.
    fpath : str or path-like, optional
        Path of a log file. When given, a file handler writing to it is
        attached.
    propagate : bool, optional
        Value assigned to the logger's ``propagate`` attribute.
    parent : logging.Logger, optional
        Parent logger. When given, the new logger is created as its child
        and no stream handler is attached to it.

    Returns
    -------
    logging.Logger
        The configured logger.
    """
    import os  # local import: only needed to normalize ``fpath`` below

    if parent is None:
        logger = logging.getLogger(name)
    else:
        logger = parent.getChild(name)
    logger.propagate = propagate
    logger.setLevel(logging.DEBUG)

    if parent is None:
        # FileHandler subclasses StreamHandler, so an exact type comparison is
        # required to detect an already attached plain stream handler.
        has_stream_handler = any(
            type(handler) is logging.StreamHandler for handler in logger.handlers
        )
        if not has_stream_handler:
            streamformatter = logging.Formatter('%(levelname)s %(processName)s %(message)s')
            streamhandler = logging.StreamHandler()
            streamhandler.setLevel(logging.INFO)
            streamhandler.setFormatter(streamformatter)
            logger.addHandler(streamhandler)

    if fpath is not None:
        # FileHandler stores the absolute path of its target in baseFilename.
        target = os.path.abspath(os.fspath(fpath))
        has_file_handler = any(
            isinstance(handler, logging.FileHandler) and handler.baseFilename == target
            for handler in logger.handlers
        )
        if not has_file_handler:
            fileformatter = logging.Formatter('%(asctime)s %(name)s %(levelname)s %(message)s')
            filehandler = logging.FileHandler(fpath)
            filehandler.setLevel(logging.DEBUG)
            filehandler.setFormatter(fileformatter)
            logger.addHandler(filehandler)

    return logger
[docs]
def logp_from_row(row):
    """Recover LogP from the leading digits of a row's seed string.

    The part of ``row['SEED']`` before the first ``'_'`` is assumed to end
    with one separator character followed by the digits encoding the mass
    ratio and the eccentricity. Those digits are rebuilt from the row only to
    learn how many characters to strip; what remains is read as LogP scaled
    by ``1e5``.

    Parameters
    ----------
    row : mapping
        Row with ``'SEED'``, ``'Eccentricity@ZAMS'`` and ``'Mass_Ratio'``
        entries.

    Returns
    -------
    numpy.float32
        The leading seed digits divided by ``1e5``.
    """
    seed_head = row['SEED'].split('_')[0]
    eccentricity = row['Eccentricity@ZAMS']
    ratio = row['Mass_Ratio']

    ratio_digits = str(int(np.trunc((ratio - np.float32(1e-6)) * np.float32(1e6))))
    ecc_digits = str(int(np.trunc(eccentricity * np.float32(1e3))))
    suffix_length = len(ratio_digits) + len(ecc_digits)

    # The extra -1 drops the single character sitting between the LogP digits
    # and the mass ratio/eccentricity digits.
    logp_digits = seed_head[:len(seed_head) - suffix_length - 1]
    return np.float32(logp_digits) / np.float32(1e5)
[docs]
def step(array, index_array, midpoint_i):
    """Discard the half of ``array`` on the uphill side of its midpoint.

    The element just left of the midpoint is compared with the midpoint
    element. If the values rise from left to right there, the part before the
    midpoint is kept; otherwise the part from the midpoint onwards is kept.

    Parameters
    ----------
    array : sequence
        Values being searched.
    index_array : sequence
        Labels aligned with ``array``; it is sliced identically.
    midpoint_i : int
        Position of the midpoint within ``array``.

    Returns
    -------
    sub_arr : sequence
        The retained part of ``array``.
    sub_indarr : sequence
        The matching part of ``index_array``.
    midpoint_i : int
        Midpoint of the retained part, ``ceil(len(sub_arr) / 2)``.
    """
    keep_left = (array[midpoint_i - 1] - array[midpoint_i]) < 0
    kept = slice(None, midpoint_i) if keep_left else slice(midpoint_i, None)
    sub_arr = array[kept]
    sub_indarr = index_array[kept]
    return sub_arr, sub_indarr, ceil(len(sub_arr) / 2)
[docs]
def valley_minimum(array, index_array):
    """Locate the bottom of a valley in ``array`` by repeated halving.

    ``step`` is applied until a single element remains.
    NOTE(review): presumably ``array`` must decrease and then increase
    (a single valley) for the result to be its minimum -- confirm.

    Parameters
    ----------
    array : sequence
        Values to search.
    index_array : sequence
        Labels aligned with ``array``.

    Returns
    -------
    tuple
        ``(label, value)`` of the element that remains.
    """
    remaining = array
    pivot = ceil(len(array) / 2)
    while len(remaining) > 1:
        remaining, index_array, pivot = step(remaining, index_array, pivot)
    return index_array[0], remaining[0]
[docs]
def get_bin_frequency_heights(x_array, bin_edges):
    """Compute count-per-unit-width heights for a set of bins.

    Bins are processed from left to right. For each bin, every remaining
    value that is not ``>=`` the bin's upper edge is counted and removed.
    Values below the first edge are therefore counted in the first bin, and
    values at or above the last edge are never counted.

    Parameters
    ----------
    x_array : numpy array
        Values to bin.
    bin_edges : numpy array
        1-dimensional array of bin edges.

    Returns
    -------
    numpy array
        ``float32`` array with one entry per bin: the bin's count divided by
        its width.
    """
    heights = np.zeros(bin_edges.shape[0] - 1, np.float32)
    remaining = x_array
    n_before = len(remaining)
    for i, (lower, upper) in enumerate(zip(bin_edges[:-1], bin_edges[1:])):
        remaining = remaining[remaining >= upper]
        n_after = len(remaining)
        heights[i] = (n_before - n_after) / (upper - lower)
        n_before = n_after
    return heights
[docs]
def get_linear_fit(xy0, xy1):
    """Describe the straight line through two points.

    Parameters
    ----------
    xy0, xy1 : pair of numbers
        The ``(x, y)`` coordinates of the two points.

    Returns
    -------
    numpy array
        ``[x0, x1, slope, intercept]`` of the line through both points.
    """
    (x0, y0), (x1, y1) = xy0, xy1
    rise = y1 - y0
    run = x1 - x0
    slope = rise / run
    return np.array([x0, x1, slope, y0 - slope * x0])
[docs]
def get_linear_fit_area(linear_fit, x0, x1):
    """Compute an area measure for a linear fit between ``x0`` and ``x1``.

    The integrals of the slope term and of the intercept term over
    ``[x0, x1]`` are taken in absolute value separately and then summed.

    Parameters
    ----------
    linear_fit : sequence of four numbers
        ``[fit x0, fit x1, slope, intercept]``; only the slope and the
        intercept are used.
    x0, x1 : number
        Integration limits.

    Returns
    -------
    number
        ``|slope * (x1^2 - x0^2) / 2| + |intercept * (x1 - x0)|``.
    """
    _, _, slope, intercept = linear_fit
    slope_term = np.abs(slope * (x1 * x1 - x0 * x0) / 2)
    intercept_term = np.abs(intercept * (x1 - x0))
    return slope_term + intercept_term
[docs]
def get_bin_centers(bin_edges):
    """Return the midpoint of every pair of consecutive bin edges.

    Parameters
    ----------
    bin_edges : sequence of numbers
        Bin edges.

    Returns
    -------
    numpy array
        One value per bin: the mean of its lower and upper edge.
    """
    edge_pairs = zip(bin_edges[:-1], bin_edges[1:])
    return np.array([(lower + upper) / 2 for lower, upper in edge_pairs])
[docs]
def enumerate_bin_edges(bin_edges):
    """Enumerate the ``(lower, upper)`` edge pairs of consecutive bins.

    Parameters
    ----------
    bin_edges : sequence
        Bin edges.

    Returns
    -------
    enumerate
        Iterator yielding ``(i, (lower_edge, upper_edge))`` for each bin.
    """
    lower_edges = bin_edges[:-1]
    upper_edges = bin_edges[1:]
    return enumerate(zip(lower_edges, upper_edges))