from keckdrpframework.primitives.base_img import BaseImg
from kcwidrp.primitives.kcwi_file_primitives import kcwi_fits_reader, \
kcwi_fits_writer, strip_fname
from kcwidrp.primitives.GetAtlasLines import gaus
from kcwidrp.core.kcwi_get_std import kcwi_get_std
from kcwidrp.core.bokeh_plotting import bokeh_plot
from kcwidrp.core.kcwi_plotting import save_plot
from kcwidrp.core.bspline import Bspline
from bokeh.plotting import figure
import os
import time
import numpy as np
from scipy.optimize import curve_fit
from astropy.io import fits
class MakeMasterSky(BaseImg):
    r"""
    Make master sky image.

    Uses b-spline fits along with geometry maps to generate a master sky image
    for sky subtraction.

    This routine also handles the file `kcwi.sky`, which controls the master
    sky generation.  This file consists of one line per image, with the first
    column indicating the raw object image to be sky-subtracted.  The following
    columns can either indicate a separate image to use for sky subtraction,
    the filename of a mask fits image for masking object flux, or indicate
    that the object is a continuum source and either automatically find the
    object, or specify the location and width of the continuum source.  Below
    are example one-line entries and what they mean:

    1. Skip sky subtraction for this particular object image:

        * kr230925_00075.fits skip

    2. Point to a different image for the sky (this assumes the \*_sky.fits
       image has already been generated previously):

        * kr230925_00075.fits kr230925_00076.fits

    3. Indicate that a mask file should be used to mask object flux when
       deriving the sky model (see kcwi_masksky_ds9.py):

        * kr230925_00075.fits kr230925_00075_smsk.fits

    4. Indicate that this is a bright continuum source and automatically mask
       the continuum source from the sky model:

        * kr230925_00075.fits cont

    5. Indicate that this is a faint continuum source and specify the location
       and the width of the continuum source (in pixels):

        * kr230925_00075.fits cont 45.0 7.6

    If no `kcwi.sky` file exists, or there is no entry for the input object
    frame, then the entire image is used to generate the sky model.

    It is good practice to run all the data through first, then inspect the
    sky subtraction and see which frames will benefit from masking or from a
    dedicated sky observation.

    If a sky model is generated, the routine will write out a \*_sky.fits image
    and add a sky entry in the proc table.
    """

    # Slice map values 0 .. _N_SLICES - 1 are treated as valid slices
    _N_SLICES = 24

    def __init__(self, action, context):
        BaseImg.__init__(self, action, context)
        self.logger = context.pipeline_logger

    def _pre_condition(self):
        """
        Checks if we can create a master sky.

        Determines whether the frame is a standard star, parses `kcwi.sky`
        (read from the current working directory) for an entry matching this
        frame, and records the outcome on ``self.action.args`` (``stdfile``,
        ``stdname``, ``skyfile``, ``skymask``, ``contsky``,
        ``cont_source_pos``, ``cont_source_width``).

        Returns:
            bool: False when sky subtraction is disabled in the config, when
            the `kcwi.sky` entry says ``skip``, or when the requested
            alternate master sky already exists; True otherwise.
        """
        self.logger.info("Checking precondition for MakeMasterSky")
        if self.config.instrument.skipsky:
            self.logger.warning("Sky subtraction turned off, "
                                "skipping MakeMasterSky")
            return False

        suffix = 'sky'
        ofn = self.action.args.name
        rdir = self.config.instrument.output_directory

        # Are we a standard star?
        stdfile = None
        stdname = None
        if 'object' in self.action.args.imtype.lower():
            self.logger.info("Checking OBJECT keyword")
            stdfile, stdname = kcwi_get_std(
                self.action.args.ccddata.header['OBJECT'], self.logger)
            if not stdfile:
                self.logger.info("Checking TARGNAME keyword")
                stdfile, stdname = kcwi_get_std(
                    self.action.args.ccddata.header['TARGNAME'], self.logger)
        else:
            self.logger.warning("Not object type: %s" %
                                self.action.args.imtype)
        self.action.args.stdfile = stdfile
        self.action.args.stdname = stdname

        # Is there a kcwi.sky file?
        skyfile = None
        skymask = None
        contsky = False
        cont_source_pos = None
        cont_source_width = None
        if os.path.exists('kcwi.sky'):
            self.logger.info("Reading kcwi.sky")
            # context manager guarantees the handle is released on error
            with open('kcwi.sky') as sky_list:
                skyproc = sky_list.readlines()
            # is our file in the list?
            for row in skyproc:
                # skip comments
                if row.startswith('#'):
                    continue
                cols = row.split()
                # skip empty lines
                if not cols:
                    continue
                # Row formats:
                #   <raw sci file> <raw sky file> <optional mask file>
                #   <raw sci file> skip
                #   <raw sci file> cont [<position> <width>]
                if ofn not in cols[0]:
                    continue
                if len(cols) < 2:
                    # a lone filename carries no directive; previously this
                    # raised IndexError
                    self.logger.warning("Malformed kcwi.sky entry for %s: "
                                        "no second column, ignoring" % ofn)
                    continue
                skyfile = cols[1]
                # Should we skip sky subtraction?
                if 'skip' in skyfile:
                    self.logger.info("Skipping sky subtraction for %s" % ofn)
                    keycom = 'sky corrected?'
                    self.action.args.ccddata.header['SKYCOR'] = (False,
                                                                 keycom)
                    return False
                elif 'cont' in skyfile:
                    self.logger.info("Using continuum source local sky for"
                                     " %s" % ofn)
                    contsky = True
                    if len(cols) == 4:
                        cont_source_pos = float(cols[2])
                        cont_source_width = float(cols[3])
                        self.logger.info("Using input continuum pos of "
                                         "%.2f with width of %.2f" %
                                         (cont_source_pos,
                                          cont_source_width))
                # Do we have an optional sky mask file?
                elif len(cols) > 2:
                    skymask = cols[2]
                    self.logger.info("Found sky mask entry for %s: %s"
                                     % (ofn, skymask))
                self.logger.info("Found sky entry for %s: %s" % (ofn,
                                                                 skyfile))

        # Do we have a mask file, and does it exist?
        if skymask:
            if os.path.exists(skymask):
                self.logger.info("Using sky mask file: %s" % skymask)
            else:
                self.logger.warning("Sky mask file not found: %s" % skymask)
                skymask = None

        # Record results
        self.action.args.skyfile = skyfile
        self.action.args.skymask = skymask
        self.action.args.contsky = contsky
        self.action.args.cont_source_pos = cont_source_pos
        self.action.args.cont_source_width = cont_source_width

        # Do we have a sky alternate?
        # NOTE(review): a 'cont' entry also reaches this branch and produces
        # a "cont_sky.fits not found" warning; kept as-is because the
        # downstream use of args.skyfile is not visible from here.
        if skyfile:
            # Generate sky file name
            msname = skyfile.split('.fits')[0] + '_' + suffix + '.fits'
            mskyf = os.path.join(rdir, msname)
            if os.path.exists(mskyf):
                self.logger.info("Master sky already exists: %s" % mskyf)
                return False
            self.logger.warning("Alternate master sky %s not found."
                                % mskyf)
            return True

        self.logger.info("No alternate master sky requested.")
        return True

    def _read_geometry_map(self, fname):
        """Read one geometry map image from the `redux` directory in cwd."""
        self.logger.info("Reading image: %s" % fname)
        return kcwi_fits_reader(os.path.join(
            self.config.instrument.cwd, 'redux', fname))[0]

    def _pause_after_plot(self):
        """Wait for <cr> at plot_level >= 2, else sleep for plot_pause."""
        if self.config.instrument.plot_level >= 2:
            input("Next? <cr>: ")
        else:
            time.sleep(self.config.instrument.plot_pause)

    def _find_brightest_source(self, slc, wav, pos, flux, interior, wav_ran):
        """
        Find the slice with the largest flux scatter and fit its profile.

        Args:
            slc, wav, pos, flux: flattened slice, wavelength, position and
                image arrays of equal length.
            interior: boolean array selecting pixels away from slice edges.
            wav_ran: (low, high) wavelength window, exclusive bounds.

        Returns:
            tuple: (slice number, positions, fluxes, peak position,
            peak flux, fitted parameters from curve_fit with gaus).
        """
        window = interior & (wav > wav_ran[0]) & (wav < wav_ran[1])
        best_slice = -1
        best_sig = -1.
        best_pos = None
        best_flx = None
        for si in range(self._N_SLICES):
            sq = np.flatnonzero((slc == si) & window)
            xplt = pos[sq]
            yplt = flux[sq]
            sig = float(np.nanstd(yplt))
            self.logger.info("Slice %d - StDev = %.2f" % (si, sig))
            if sig > best_sig:
                best_sig = sig
                best_slice = si
                best_pos = xplt
                best_flx = yplt
        ipk = np.argmax(best_flx)
        ppk = best_pos[ipk]
        fpk = best_flx[ipk]
        # gaussian fit to max slice, seeded at the brightest pixel
        fit, _ = curve_fit(gaus, best_pos, best_flx, p0=[fpk, ppk, 1.])
        return best_slice, best_pos, best_flx, ppk, fpk, fit

    def _plot_source_fit(self, title, xdata, ydata, ppk, fpk, fit, edges):
        """Plot a slice profile, its gaussian fit and the mask edges."""
        xx = np.arange(np.min(xdata), np.max(xdata), 1)
        yy = gaus(xx, fit[0], fit[1], fit[2])
        p = figure(
            title=title,
            x_axis_label='Pos (x px)',
            y_axis_label='Flux (e-)',
            plot_width=self.config.instrument.plot_width,
            plot_height=self.config.instrument.plot_height)
        p.circle(xdata, ydata, size=1, line_alpha=0., fill_color='purple',
                 legend_label='Data')
        p.line([ppk, ppk], [0, fpk], color='green')
        for edge in edges:
            p.line([edge, edge], [0, fpk], color='blue')
        p.line(xx, yy, color='red')
        bokeh_plot(p, self.context.bokeh_session)
        self._pause_after_plot()

    def _fit_sky_spline(self, waves, fluxes, n):
        """Fit a b-spline with n evenly spaced intervals over waves."""
        wmin = np.min(waves)
        bkpt = wmin + np.arange(n + 1) * (np.max(waves) - wmin) / n
        self.logger.info("Nknots = %d, min = %.2f, max = %.2f (A)" %
                         (n, np.min(bkpt), np.max(bkpt)))
        return Bspline.iterfit(waves, fluxes, fullbkpt=bkpt,
                               upper=1, lower=1)

    def _perform(self):
        """
        Build the master sky model, write it out and update the proc table.

        Returns:
            self.action.args, with the original image data and header
            restored after the sky image has been written.
        """
        self.logger.info("Creating master sky")
        suffix = 'sky'

        # get root for maps
        tab = self.context.proctab.search_proctab(
            frame=self.action.args.ccddata, target_type='MARC',
            target_group=self.action.args.groupid)
        if len(tab) <= 0:
            self.logger.error("Geometry not solved!")
            return self.action.args
        groot = strip_fname(tab['filename'][0])

        # Wavelength, slice and position map images
        wmf = groot + '_wavemap.fits'
        wavemap = self._read_geometry_map(wmf)
        slf = groot + '_slicemap.fits'
        slicemap = self._read_geometry_map(slf)
        pof = groot + '_posmap.fits'
        posmap = self._read_geometry_map(pof)

        posmax = np.nanmax(posmap.data)
        posbuf = int(10. / self.action.args.xbinsize)
        ny = posmap.data.shape[0]

        # wavelength region
        wavegood0 = wavemap.header['WAVGOOD0']
        wavegood1 = wavemap.header['WAVGOOD1']
        waveall0 = wavemap.header['WAVALL0']
        waveall1 = wavemap.header['WAVALL1']
        wavemid = wavemap.header['WAVMID']

        # get image size
        sm_sz = self.action.args.ccddata.data.shape

        # flattened (C-order) arrays so one index addresses all maps
        slc = slicemap.data.ravel()
        wav = wavemap.data.ravel()
        pos = posmap.data.ravel()
        flux = self.action.args.ccddata.data.ravel()

        # pixels at least posbuf away from either slice edge
        interior = (pos > posbuf) & (pos < (posmax - posbuf))
        valid_slice = (slc >= 0) & (slc <= self._N_SLICES - 1)
        in_waveall = (wav >= waveall0) & (wav <= waveall1)

        # sky masking
        # default is no masking (True = mask, False = don't mask)
        binary_mask = np.zeros(sm_sz, dtype=bool)
        # was sky masking requested?
        if self.action.args.skymask:
            if os.path.exists(self.action.args.skymask):
                self.logger.info("Reading sky mask file: %s"
                                 % self.action.args.skymask)
                with fits.open(self.action.args.skymask) as hdul:
                    mask_data = hdul[0].data
                    if mask_data is not None:
                        # in-memory boolean copy, so the file can be closed;
                        # any non-zero pixel counts as masked
                        mask_data = np.asarray(mask_data) != 0
                # verify size match
                if mask_data is None or mask_data.shape != sm_sz:
                    self.logger.warning("Sky mask size mis-match: "
                                        "masking disabled")
                else:
                    binary_mask = mask_data
            else:
                self.logger.warning("Sky mask image not found: %s"
                                    % self.action.args.skymask)

        auto_masked = False
        auto_mask_type = ""
        auto_cont_pos = None
        auto_cont_width = None
        # Use 10% of wavelength range at wavemid for source finding
        src_wav_ran = (wavemid - 0.05 * (wavegood1 - wavegood0),
                       wavemid + 0.05 * (wavegood1 - wavegood0))

        # if we are a standard, get mask for bright continuum source
        if self.action.args.stdname is not None:
            self.logger.info("Standard star observation of "
                             "%s will be auto-masked" %
                             self.action.args.stdname)
            auto_masked = True
            auto_mask_type = "Std Star"
            self.logger.info("Finding the std max slice")
            std_sl_max, xdata, ydata, ppk, fpk, res = \
                self._find_brightest_source(slc, wav, pos, flux, interior,
                                            src_wav_ran)
            # a negative fitted width would invert the mask window;
            # assumes the width enters gaus symmetrically - TODO confirm
            std_width = abs(res[2])
            self.logger.info("Std max at %.2f in slice %d with width %.2f px"
                             % (res[1], std_sl_max, std_width))
            std_pos_mask_0 = res[1] - 5. * std_width
            std_pos_mask_1 = res[1] + 5. * std_width
            self.logger.info("Masking between %.2f and %.2f" %
                             (std_pos_mask_0, std_pos_mask_1))
            auto_cont_pos = res[1]
            auto_cont_width = 5. * std_width
            # Mask standard from sky calculation
            binary_mask.flat[np.flatnonzero(
                (pos > std_pos_mask_0) & (pos < std_pos_mask_1))] = True
            # plot, if requested
            if self.config.instrument.plot_level >= 1:
                self._plot_source_fit(
                    self.action.args.plotlabel +
                    ' Std max sl %d' % std_sl_max,
                    xdata, ydata, ppk, fpk, res,
                    [std_pos_mask_0, std_pos_mask_1])
        # if we are a continuum source,
        # get local mask for bright continuum source
        elif self.action.args.contsky:
            self.logger.info("continuum source observation will be "
                             "auto-masked")
            auto_masked = True
            # extra sky window kept on either side of the source
            sky_win = 14 / self.action.args.xbinsize
            if self.action.args.cont_source_pos is None:
                self.logger.info("Finding the continuum source "
                                 "automatically")
                auto_mask_type = "AutoCont"
                con_sl_max, xdata, ydata, ppk, fpk, res = \
                    self._find_brightest_source(slc, wav, pos, flux,
                                                interior, src_wav_ran)
                # see note on negative widths in the standard star branch
                con_width = abs(res[2])
                self.logger.info("Continuum source max at %.2f in "
                                 "slice %d with width %.2f px"
                                 % (res[1], con_sl_max, con_width))
                # First define source extent
                con_pos_mask_0 = res[1] - 7. * con_width
                con_pos_mask_1 = res[1] + 7. * con_width
                auto_cont_pos = res[1]
                auto_cont_width = 7. * con_width
                # Next define lower and upper windows
                con_pos_mask_lo_0 = con_pos_mask_0 - sky_win
                con_pos_mask_up_1 = con_pos_mask_1 + sky_win
                # plot, if requested
                if self.config.instrument.plot_level >= 1:
                    self._plot_source_fit(
                        self.action.args.plotlabel +
                        ' Cont source max sl %d' % con_sl_max,
                        xdata, ydata, ppk, fpk, res,
                        [con_pos_mask_lo_0, con_pos_mask_0,
                         con_pos_mask_1, con_pos_mask_up_1])
            else:
                self.logger.info("Using input source position of %.2f and "
                                 "source width of %.2f" %
                                 (self.action.args.cont_source_pos,
                                  self.action.args.cont_source_width))
                auto_mask_type = "UserCont"
                auto_cont_pos = self.action.args.cont_source_pos
                auto_cont_width = self.action.args.cont_source_width
                # First define source extent
                con_pos_mask_0 = auto_cont_pos - auto_cont_width
                con_pos_mask_1 = auto_cont_pos + auto_cont_width
                # Next define lower and upper windows
                con_pos_mask_lo_0 = con_pos_mask_0 - sky_win
                con_pos_mask_up_1 = con_pos_mask_1 + sky_win
            self.logger.info("Masking all but sky region between "
                             "%.2f and %.2f and between %.2f and %.2f" %
                             (con_pos_mask_lo_0, con_pos_mask_0,
                              con_pos_mask_1, con_pos_mask_up_1))
            # Mask all but local sky from sky calculation
            not_local_sky = \
                ((pos > 0) & (pos < con_pos_mask_lo_0)) | \
                ((pos > con_pos_mask_0) & (pos < con_pos_mask_1)) | \
                ((pos > con_pos_mask_up_1) & (pos < posmax))
            binary_mask.flat[np.flatnonzero(not_local_sky)] = True

        # count masked pixels
        tmsk = int(np.count_nonzero(binary_mask))
        self.logger.info("Number of pixels masked = %d" % tmsk)

        finiteflux = np.isfinite(flux)

        # get un-masked points mapped to exposed regions on CCD
        fit_ok = valid_slice & interior & in_waveall & finiteflux & \
            np.logical_not(binary_mask.ravel())
        blue = self.action.args.camera == 0
        # handle dichroic bad region
        if self.action.args.dich:
            if blue:
                fit_ok &= ~((slc > 20) & (wav > 5600.))
            else:
                fit_ok &= ~((slc > 20) & (wav < 5600.))
        if not blue:
            # Red: trim junk at the ends using the row (y) index of each pixel
            rows = np.repeat(np.arange(ny), pos.size // ny)
            fit_ok &= (rows >= 50) & (rows <= (ny - 50))
        q = np.flatnonzero(fit_ok)

        # get all points mapped to exposed regions on the CCD (for output)
        qo = np.flatnonzero(valid_slice & (pos >= 0) & in_waveall &
                            finiteflux)

        # extract relevant image values and wavelengths
        fluxes = flux[q]
        waves = wav[q]
        self.logger.info("Number of fit waves = %d" % len(waves))
        # keep output wavelengths
        owaves = wav[qo]
        self.logger.info("Number of output waves = %d" % len(owaves))

        # sort on wavelength
        s = np.argsort(waves)
        waves = waves[s]
        fluxes = fluxes[s]

        # knots per pixel
        knotspp = self.config.instrument.KNOTSPP
        n = int(sm_sz[0] * knotspp)

        # do bspline fit
        sft0, gmask = self._fit_sky_spline(waves, fluxes, n)
        gp = np.flatnonzero(gmask)
        yfit1, _ = sft0.value(waves)
        self.logger.info("Number of good points = %d" % len(gp))

        # check result
        if np.max(yfit1) < 0:
            self.logger.warning("B-spline failure")
            if n > 2000:
                # retry with the next lower knot count
                if n == 5000:
                    n = 2000
                if n == 8000:
                    n = 5000
                sft0, gmask = self._fit_sky_spline(waves, fluxes, n)
                # keep the good-point list in step with the final fit
                gp = np.flatnonzero(gmask)
                yfit1, _ = sft0.value(waves)
            if np.max(yfit1) <= 0:
                self.logger.warning("B-spline final failure, sky is zero")

        # get values at original wavelengths
        yfit, _ = sft0.value(owaves)

        # for plotting
        gwaves = waves[gp]
        gfluxes = fluxes[gp]
        npts = len(gwaves)
        # a stride of 0 (fewer than 8000 points) is an invalid slice step
        stride = max(1, int(npts / 8000.))
        xplt = gwaves[::stride]
        yplt = gfluxes[::stride]
        fplt, _ = sft0.value(xplt)
        yrng = [np.min(yplt), np.max(yplt)]
        self.logger.info("Stride = %d" % stride)

        # plot, if requested
        if self.config.instrument.plot_level >= 1:
            # output filename stub
            skyfnam = "sky_%05d_%s_%s_%s" % \
                (self.action.args.ccddata.header['FRAMENO'],
                 self.action.args.illum, self.action.args.grating,
                 self.action.args.ifuname)
            p = figure(
                title=self.action.args.plotlabel + ' Master Sky',
                x_axis_label='Wave (A)',
                y_axis_label='Flux (e-)',
                plot_width=self.config.instrument.plot_width,
                plot_height=self.config.instrument.plot_height)
            p.circle(xplt, yplt, size=1, line_alpha=0., fill_color='purple',
                     legend_label='Data')
            p.line(xplt, fplt, line_color='red', legend_label='Fit')
            p.line([wavegood0, wavegood0], yrng, line_color='green')
            p.line([wavegood1, wavegood1], yrng, line_color='green')
            p.y_range.start = yrng[0]
            p.y_range.end = yrng[1]
            bokeh_plot(p, self.context.bokeh_session)
            self._pause_after_plot()
            save_plot(p, filename=skyfnam+".png")

        # create sky image
        sky = np.zeros(self.action.args.ccddata.data.shape, dtype=float)
        sky.flat[qo] = yfit

        # store original data, header so they can be restored after writing
        img = self.action.args.ccddata.data
        hdr = self.action.args.ccddata.header.copy()
        self.action.args.ccddata.data = sky

        # get master sky output name
        ofn_full = self.action.args.name
        ofn = os.path.basename(ofn_full)
        msname = strip_fname(ofn) + '_' + suffix + '.fits'

        log_string = MakeMasterSky.__module__
        header = self.action.args.ccddata.header
        header['IMTYPE'] = 'SKY'
        header['HISTORY'] = log_string
        header['SKYMODEL'] = (True, 'sky model image?')
        header['SKYIMAGE'] = (ofn, 'image used for sky model')
        if tmsk > 0:
            header['SKYMSK'] = (True, 'was sky masked?')
            if auto_masked:
                header['AUTOMASK'] = (True, 'auto-masked?')
                header['AUTMSKTY'] = (auto_mask_type, 'Type of auto-masking')
                header['CONTPOS'] = (
                    auto_cont_pos, 'Position in slice of continuum')
                header['CONTWID'] = (
                    auto_cont_width, 'Width of continuum source')
        else:
            header['SKYMSK'] = (False, 'was sky masked?')
        header['WAVMAPF'] = wmf
        header['SLIMAPF'] = slf
        header['POSMAPF'] = pof

        # output master sky
        kcwi_fits_writer(self.action.args.ccddata, output_file=msname,
                         output_dir=self.config.instrument.output_directory)
        self.context.proctab.update_proctab(frame=self.action.args.ccddata,
                                            suffix=suffix,
                                            newtype="SKY",
                                            filename=self.action.args.name)
        self.context.proctab.write_proctab(
            tfil=self.config.instrument.procfile)

        # restore original image
        self.action.args.ccddata.data = img
        self.action.args.ccddata.header = hdr

        self.logger.info(log_string)
        return self.action.args
    # END: class MakeMasterSky()