Source code for kcwidrp.primitives.GetAtlasLines

from keckdrpframework.primitives.base_primitive import BasePrimitive
from kcwidrp.core.bokeh_plotting import bokeh_plot
from kcwidrp.core.kcwi_plotting import save_plot
from kcwidrp.primitives.kcwi_file_primitives import plotlabel

from bokeh.plotting import figure
from bokeh.models import Range1d
import numpy as np
import scipy as sp
from scipy.interpolate import interpolate
from scipy.signal.windows import boxcar
from scipy.optimize import curve_fit
from scipy.stats import sigmaclip
import time
import os
import traceback


[docs]def gaus(x, a, mu, sigma): """ Gaussian fitting function. Args: x (float): x value a (float): amplitude mu (float): x position of center of Gaussian sigma (float): Gaussian width Returns: Gaussian value at x position """ return a * np.exp(-(x - mu) ** 2 / (2. * sigma ** 2))
[docs]def get_line_window(y, c, thresh=0., logger=None, strict=False, maxwin=100, frac_max=0.5): """ Find a window that includes the specified amount of the line. Get a useful window on a given line to allow finding an accurate peak for each line. Start at input center and iterate above and below until the window encompasses the specified fraction of the line maximum flux. Rejects lines that are problematic. Args: y (list of floats): flux values c (float): the nominal center of the line in pixels thresh (float): threshhold for rejection (default: 0.) logger (log object): logging instance for output strict (bool): should we apply strict criterion? (default: False) maxwin (int): largest possible window in pixels (default: 100) frac_max (float): fraction of maximum flux to encompass (default: 0.5) :returns: - x0 (int): lower limit of window in pixels - x1 (int): upper limit of window in pixels - count (int): number of pixels in window """ verbose = logger is not None nx = len(y) # check edges if c < 2 or c > nx - 2: if verbose: logger.info("input center too close to edge") return None, None, 0 # get initial values x0 = c - 2 x1 = c + 2 mx = np.nanmax(y[x0:x1+1]) count = 5 # check low side if x0 - 1 < 0: if verbose: logger.info("max check: low edge hit") return None, None, 0 while y[x0-1] > mx: x0 -= 1 count += 1 if x0 - 1 < 0: if verbose: logger.info("Max check: low edge hit") return None, None, 0 # check high side if x1 + 1 >= nx: if verbose: logger.info("max check: high edge hit") return None, None, 0 while y[x1+1] > mx: x1 += 1 count += 1 if x1 + 1 >= nx: if verbose: logger.info("Max check: high edge hit") return None, None, 0 # how big is our window? if (x1 - x0) > maxwin: if verbose: logger.info("Window expanded beyond limit") return None, None, 0 # adjust starting window to center on max cmx = x0 + y[x0:x1+1].argmax() x0 = cmx - 2 x1 = cmx + 2 mx = np.nanmax(y[x0:x1 + 1]) # make sure max is high enough if mx < thresh: return None, None, 0 # # expand until we get to max_frac fmx = mx * frac_max # # Low index side prev = mx while y[x0] > fmx: if y[x0] > mx or x0 <= 0 or y[x0] > prev: if verbose: if y[x0] > mx: logger.info("frac_max check: low index err - missed max") if x0 <= 0: logger.info("frac_max check: low index err - at edge") if y[x0] > prev: logger.info("frac_max check: low index err - wiggly") return None, None, 0 prev = y[x0] x0 -= 1 count += 1 # High index side prev = mx while y[x1] > fmx: if y[x1] > mx or x1 >= nx or y[x1] > prev: if verbose: if y[x1] > mx: logger.info("frac_max check: high index err - missed max") if x1 >= nx: logger.info("frac_max check: high index err - at edge") if y[x1] > prev: logger.info("frac_max check: high index err - wiggly") return None, None, 0 prev = y[x1] if x1 < (nx-1): x1 += 1 count += 1 else: if verbose: logger.info("Edge encountered") return None, None, 0 if strict: # where did we end up? if c < x0 or x1 < c: if verbose: logger.info("initial position outside final window") return None, None, 0 return x0, x1, count
# END: get_line_window()
[docs]def findpeaks(x, y, wid, sth, ath, pkg=None, verbose=False): """ Find peaks in spectrum Uses a gradient to determine where peaks are and then performs some cleaning of blended lines using sigma rejection. Args: x (list of floats): x values of vector y (list of floats): y values of vector wid (int): boxcar smoothing width in pixels sth (float): slope threshhold ath (float): amplitude threshhold pkg (float): limit on peak wandering in pixels verbose (bool): control output messages :returns: - list of peaks - average sigma of cleaned peaks - list of y values at peaks """ # derivative grad = np.gradient(y) # smooth derivative win = boxcar(wid) d = sp.signal.convolve(grad, win, mode='same') / sum(win) # size nx = len(x) # set up windowing if not pkg: pkg = wid hgrp = int(pkg/2) hgt = [] pks = [] sgs = [] # loop over spectrum # limits to avoid edges given pkg for i in np.arange(pkg, (nx - pkg)): # find zero crossings if np.sign(d[i]) > np.sign(d[i+1]): # pass slope threshhold? if (d[i] - d[i+1]) > sth * y[i]: # pass amplitude threshhold? if y[i] > ath or y[i+1] > ath: # get subvectors around peak in window xx = x[(i-hgrp):(i+hgrp+1)] yy = y[(i-hgrp):(i+hgrp+1)] if len(yy) > 3: try: # gaussian fit res, _ = curve_fit(gaus, xx, yy, p0=[y[i], x[i], 1.]) # check offset of fit from initial peak r = abs(x - res[1]) t = r.argmin() if abs(i - t) > pkg: if verbose: print(i, t, x[i], res[1], x[t]) else: hgt.append(res[0]) pks.append(res[1]) sgs.append(abs(res[2])) except RuntimeError: continue # clean by sigmas cvals = [] cpks = [] sgmn = None if len(pks) > 0: cln_sgs, low, upp = sigmaclip(sgs, low=3., high=3.) for i in range(len(pks)): # clean only blends (overly wide lines) if sgs[i] < upp: cpks.append(pks[i]) cvals.append(hgt[i]) sgmn = cln_sgs.mean() # sgmd = float(np.nanmedian(cln_sgs)) else: print("No peaks found!") return cpks, sgmn, cvals
# END: findpeaks()
[docs]class GetAtlasLines(BasePrimitive): """ Get relevant atlas line positions and wavelengths Generates a list of well observed atlas lines for generating the final wavelength solution. Uses the following configuration parameters: * FRACMAX: fraction of max to use for window for fitting (defaults to 0.5) * LINELIST: an input line list to use instead of determining one on the fly """ def __init__(self, action, context): BasePrimitive.__init__(self, action, context) self.logger = context.pipeline_logger self.action.args.atminrow = None self.action.args.atmaxrow = None self.action.args.atminwave = None self.action.args.atmaxwave = None self.action.args.at_wave = None self.action.args.at_flux = None def _pre_condition(self): self.logger.info("Checking for master arc") if 'MARC' in self.action.args.ccddata.header['IMTYPE']: return True else: return False def _perform(self): """Get atlas line positions for wavelength fitting""" self.logger.info("Finding isolated atlas lines") verbose = (self.config.instrument.verbose > 1) frac_max = self.config.instrument.FRACMAX self.logger.info("Finding line windows using fraction of line max of " "%.2f" % frac_max) # get atlas wavelength range # get pixel values (no longer centered in the middle) specsz = len(self.context.arcs[self.config.instrument.REFBAR]) xvals = np.arange(0, specsz) # min, max rows, trimming the ends minrow = 50 maxrow = specsz - 50 # wavelength range mnwvs = [] mxwvs = [] refbar_disp = 1. # Get wavelengths for each bar for b in range(self.config.instrument.NBARS): waves = np.polyval(self.action.args.twkcoeff[b], xvals) mnwvs.append(np.min(waves)) mxwvs.append(np.max(waves)) if b == self.config.instrument.REFBAR: refbar_disp = self.action.args.twkcoeff[b][-2] self.logger.info("Ref bar (%d) dispersion = %.3f Ang/px" % (self.config.instrument.REFBAR, refbar_disp)) # Get extrema (trim ends a bit) minwav = min(mnwvs) + 10. maxwav = max(mxwvs) - 10. wave_range = maxwav - minwav # Do we have a dichroic? if self.action.args.dich: if self.action.args.camera == 0: # Blue maxwav = min([maxwav, 5620.]) elif self.action.args.camera == 1: # Red minwav = max([minwav, 5580.]) else: self.logger.error("Camera keyword not defined!") dichroic_fraction = (maxwav - minwav) / wave_range # Get corresponding atlas range minrw = [i for i, v in enumerate(self.action.args.refwave) if v >= minwav][0] maxrw = [i for i, v in enumerate(self.action.args.refwave) if v <= maxwav][-1] self.logger.info("Min, Max wave (A): %.2f, %.2f" % (minwav, maxwav)) if self.action.args.dich: self.logger.info("Dichroic fraction: %.3f" % dichroic_fraction) # store atlas ranges self.action.args.atminrow = minrw self.action.args.atmaxrow = maxrw self.action.args.atminwave = minwav self.action.args.atmaxwave = maxwav self.action.args.dichroic_fraction = dichroic_fraction # output filename stub atfnam = "arc_%05d_%s_%s_%s_atlines" % \ (self.action.args.ccddata.header['FRAMENO'], self.action.args.illum, self.action.args.grating, self.action.args.ifuname) # check if line list was given on command line if self.config.instrument.LINELIST: with open(self.config.instrument.LINELIST) as llfn: atlines = llfn.readlines() refws = [] refas = [] for line in atlines: if '#' in line: continue refws.append(float(line.split()[0])) refas.append(float(line.split()[1])) self.logger.info("Read %d lines from %s" % (len(refws), self.config.instrument.LINELIST)) else: # get atlas sub spectrum atspec = self.action.args.reflux[minrw:maxrw] atwave = self.action.args.refwave[minrw:maxrw] # get ref bar arc spectrum, pixel values, and prelim wavelengths subxvals = xvals[minrow:maxrow] subyvals = self.context.arcs[self.config.instrument.REFBAR][ minrow:maxrow].copy() subwvals = np.polyval( self.action.args.twkcoeff[self.config.instrument.REFBAR], subxvals) # smooth subyvals win = boxcar(3) subyvals = sp.signal.convolve(subyvals, win, mode='same') / sum(win) # find good peaks in arc spectrum smooth_width = 4 # in pixels # peak width peak_width = int(self.action.args.atsig/abs(refbar_disp)) if peak_width < 4: peak_width = 4 # slope threshold slope_thresh = 0.7 * smooth_width / 2. / 100. # slope_thresh = 0.7 * smooth_width / 1000. # more severe for arc # slope_thresh = 0.016 / peak_width # get amplitude threshold ampl_thresh = 0. self.logger.info("Using a peak_width of %d px, a slope_thresh of " "%.5f a smooth_width of %d and an ampl_thresh of " "%.3f" % (peak_width, slope_thresh, smooth_width, ampl_thresh)) arc_cent, avwsg, arc_hgt = findpeaks(subwvals, subyvals, smooth_width, slope_thresh, ampl_thresh, peak_width) avwfwhm = avwsg * 2.354 self.logger.info("Found %d lines with <sig> = %.3f (A)," " <FWHM> = %.3f (A)" % (len(arc_cent), avwsg, avwfwhm)) # fitting window based on grating type if 'H' in self.action.args.grating or \ 'M' in self.action.args.grating: fwid = avwfwhm else: fwid = avwsg # clean near neighbors spec_cent = arc_cent spec_hgt = arc_hgt # # generate an atlas line list refws = [] # atlas line wavelength refas = [] # atlas line amplitude rej_win_w = [] # win rejected atlas line wavelength rej_win_a = [] # win rejected atlas line amplitude rej_fit_w = [] # fit rejected atlas line wavelength rej_fit_a = [] # fit rejected atlas line amplitude rej_par_w = [] # par rejected atlas line wavelength rej_par_a = [] # par rejected atlas line amplitude rej_dup_w = [] # dup rejected atlas line wavelength rej_dup_a = [] # dup rejected atlas line wavelength rej_fnt_w = [] # fnt rejected atlas line wavelength rej_fnt_a = [] # fnt rejected atlas line wavelength nrej = 0 # look at each arc spectrum line for i, pk in enumerate(spec_cent): if pk <= minwav or pk >= maxwav: continue # get atlas pixel position corresponding to arc line try: line_x = [ii for ii, v in enumerate(atwave) if v >= pk][0] # get window around atlas line to fit minow, maxow, count = get_line_window( atspec, line_x, logger=(self.logger if verbose else None), frac_max=frac_max) except IndexError: count = 0 minow = None maxow = None self.logger.warning("line at edge: %d, %.2f, %.f2f" % (i, pk, max(atwave))) # is resulting window large enough for fitting? if count < 5 or not minow or not maxow: # keep track of fit rejected lines rej_win_w.append(pk) rej_win_a.append(spec_hgt[i]) nrej += 1 self.logger.info("Atlas window rejected for line %.3f" % pk) continue # get data to fit yvec = atspec[minow:maxow + 1] xvec = atwave[minow:maxow + 1] # attempt Gaussian fit try: fit, _ = curve_fit(gaus, xvec, yvec, p0=[spec_hgt[i], pk, 1.], maxfev=5000) except RuntimeError as e: # keep track of Gaussian fit rejected lines rej_fit_w.append(pk) rej_fit_a.append(spec_hgt[i]) nrej += 1 self.logger.info("Atlas Gaussian fit rejected for line " "%.3f" % pk) if verbose: tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__) self.logger.info("".join(tb_str)) continue # get interpolation function of atlas line int_line = interpolate.interp1d(xvec, yvec, kind='cubic', bounds_error=False, fill_value='extrapolate') # use very dense pixel sampling x_dense = np.linspace(min(xvec), max(xvec), num=1000) # resample line with dense sampling y_dense = int_line(x_dense) # get peak amplitude and wavelength pki = y_dense.argmax() pkw = x_dense[pki] # calculate some diagnostic parameters for the line # how many atlas pixels have we moved? xoff = abs(pkw - fit[1]) / self.action.args.refdisp # what is the wavelength offset in Angstroms? woff = abs(pkw - pk) # what fraction of the canonical fit width is the line? wrat = abs(fit[2]) / fwid # can be neg or pos # current criteria for these diagnostic parameters if woff > 5. or xoff > 1.5 or wrat > 1.1: # keep track of par rejected atlas lines rej_par_w.append(pk) rej_par_a.append(spec_hgt[i]) nrej += 1 self.logger.info("Atlas line parameters rejected for line " "%.3f" % pk) self.logger.info("woff = %.3f, xoff = %.2f, wrat = %.3f" % (woff, xoff, wrat)) continue # check for duplicated lines if pkw in refws: rej_dup_w.append(pk) rej_dup_a.append(spec_hgt[i]) nrej += 1 self.logger.info("Atlas line duplicated for line %.3f" % pk) continue self.logger.info("Atlas line accepted: %.3f" % pkw) refws.append(pkw) refas.append(y_dense[pki]) # eliminate faintest lines if we have a large number self.logger.info("number of remaining lines: %d" % len(refas)) if len(refas) > 400: # sort on flux sf = np.argsort(refas) refws = np.asarray(refws)[sf] refas = np.asarray(refas)[sf] # remove faintest two-thirds hlim = int(len(refas) * 0.67) refws = refws[hlim:] refas = refas[hlim:] rej_fnt_w = refws[:hlim] rej_fnt_a = refas[:hlim] # sort back onto wavelength sw = np.argsort(refws) refws = refws[sw].tolist() refas = refas[sw].tolist() self.logger.info("Using %d generated lines" % len(refws)) # plot final list of Atlas lines and show rejections norm_fac = np.nanmax(atspec) if self.config.instrument.plot_level >= 1: p = figure(title=plotlabel(self.action.args) + "ATLAS LINES Ngood = %d, Nrej = %d" % (len(refws), nrej), x_axis_label="Wavelength (A)", y_axis_label="Normalized Flux", plot_width=self.config.instrument.plot_width, plot_height=self.config.instrument.plot_height) p.line(subwvals, subyvals / np.nanmax(subyvals), legend_label='RefArc', color='lightgray') p.line(atwave, atspec / norm_fac, legend_label='Atlas', color='blue') # Rejected: window bad if len(rej_win_w) > 0: p.diamond(rej_win_w, rej_win_a / norm_fac, legend_label='WinRej', color='cyan', size=8) # Rejected: fit failure if len(rej_fit_w) > 0: p.diamond(rej_fit_w, rej_fit_a / norm_fac, legend_label='FitRej', color='red', size=8) # Rejected: line parameter outside range if len(rej_par_w) > 0: p.diamond(rej_par_w, rej_par_a / norm_fac, legend_label='ParRej', color='orange', size=8) # Rejected: duplicated line if len(rej_dup_w) > 0: p.diamond(rej_dup_w, rej_dup_a / norm_fac, legend_label='DupRej', color='brown', size=8) # Rejected: faint if len(rej_fnt_w) > 0: p.diamond(rej_fnt_w, rej_fnt_a / norm_fac, legend_label='FntRej', color='magenta', size=8) # Kept p.diamond(refws, refas / norm_fac, legend_label='Kept', color='green', size=16) p.line([minwav, minwav], [-0.1, 1.1], legend_label='WavLim', color='brown') p.line([maxwav, maxwav], [-0.1, 1.1], color='brown') p.x_range = Range1d(min([min(subwvals), minwav-10.]), max(subwvals)) p.y_range = Range1d(-0.04, 1.04) bokeh_plot(p, self.context.bokeh_session) if self.config.instrument.plot_level >= 2: input("Next? <cr>: ") else: time.sleep(self.config.instrument.plot_pause) save_plot(p, filename=atfnam+".png") # output directory output_dir = os.path.join(self.config.instrument.cwd, self.config.instrument.output_directory) # write out final atlas line list atlines = np.array([refws, refas]) atlines = atlines.T with open(os.path.join(output_dir, atfnam + '.txt'), 'w') as atlfn: np.savetxt(atlfn, atlines, fmt=['%12.3f', '%12.3f']) # store wavelengths, fluxes self.action.args.at_wave = refws self.action.args.at_flux = refas self.logger.info("Final atlas list has %d lines" % len(refws)) log_string = GetAtlasLines.__module__ self.action.args.ccddata.header['HISTORY'] = log_string self.logger.info(log_string) return self.action.args
# END: class GetAtlasLines()