Source code for legacypipe.oneblob

from __future__ import print_function

import numpy as np
import pylab as plt
import time

from astrometry.util.ttime import Time, CpuMeas
from astrometry.util.resample import resample_with_wcs, OverlapError
from astrometry.util.fits import fits_table
from astrometry.util.plotutils import dimshow

from tractor import Tractor, PointSource, Image, NanoMaggies, Catalog, Patch
from tractor.galaxy import DevGalaxy, ExpGalaxy, FixedCompositeGalaxy, SoftenedFracDev, FracDev, disable_galaxy_cache, enable_galaxy_cache
from tractor.patch import ModelMask

from legacypipe.survey import (SimpleGalaxy, RexGalaxy, GaiaSource,
                               LegacyEllipseWithPriors, get_rgb, IN_BLOB)
from legacypipe.runbrick import rgbkwargs, rgbkwargs_resid
from legacypipe.coadds import quick_coadds
from legacypipe.runbrick_plots import _plot_mods

[docs]def one_blob(X): ''' Fits sources contained within a "blob" of pixels. ''' if X is None: return None (nblob, iblob, Isrcs, brickwcs, bx0, by0, blobw, blobh, blobmask, timargs, srcs, bands, plots, ps, simul_opt, use_ceres, rex, refs) = X print('Fitting blob number', nblob, 'val', iblob, ':', len(Isrcs), 'sources, size', blobw, 'x', blobh, len(timargs), 'images') if len(timargs) == 0: return None hasbright = refs is not None and np.any(refs.isbright) hasmedium = refs is not None and np.any(refs.ismedium) if plots: plt.figure(2, figsize=(3,3)) plt.subplots_adjust(left=0.01, right=0.99, bottom=0.01, top=0.99) plt.figure(1) t0 = time.clock() # A local WCS for this blob blobwcs = brickwcs.get_subimage(bx0, by0, blobw, blobh) # Per-source measurements for this blob B = fits_table() B.sources = srcs B.Isrcs = Isrcs B.iblob = iblob B.blob_x0 = np.zeros(len(B), np.int16) + bx0 B.blob_y0 = np.zeros(len(B), np.int16) + by0 # Did sources start within the blob? ok,x0,y0 = blobwcs.radec2pixelxy( np.array([src.getPosition().ra for src in srcs]), np.array([src.getPosition().dec for src in srcs])) B.started_in_blob = blobmask[ np.clip(np.round(y0-1).astype(int), 0,blobh-1), np.clip(np.round(x0-1).astype(int), 0,blobw-1)] B.cpu_source = np.zeros(len(B), np.float32) B.blob_width = np.zeros(len(B), np.int16) + blobw B.blob_height = np.zeros(len(B), np.int16) + blobh B.blob_npix = np.zeros(len(B), np.int32) + np.sum(blobmask) B.blob_nimages= np.zeros(len(B), np.int16) + len(timargs) B.blob_symm_width = np.zeros(len(B), np.int16) B.blob_symm_height = np.zeros(len(B), np.int16) B.blob_symm_npix = np.zeros(len(B), np.int32) B.blob_symm_nimages = np.zeros(len(B), np.int16) B.hit_limit = np.zeros(len(B), bool) ob = OneBlob('%i'%(nblob+1), blobwcs, blobmask, timargs, srcs, bands, plots, ps, simul_opt, use_ceres, hasbright, hasmedium, rex) ob.run(B) B.blob_totalpix = np.zeros(len(B), np.int32) + ob.total_pix ok,x1,y1 = blobwcs.radec2pixelxy( np.array([src.getPosition().ra for src in B.sources]), np.array([src.getPosition().dec for src in B.sources])) B.finished_in_blob = blobmask[ np.clip(np.round(y1-1).astype(int), 0, blobh-1), np.clip(np.round(x1-1).astype(int), 0, blobw-1)] assert(len(B.finished_in_blob) == len(B)) assert(len(B.finished_in_blob) == len(B.started_in_blob)) B.brightblob = np.zeros(len(B), np.int16) if hasbright: B.brightblob += IN_BLOB['BRIGHT'] if hasmedium: B.brightblob += IN_BLOB['MEDIUM'] if refs is not None and 'iscluster' in refs.get_columns() and np.any(refs.iscluster): B.brightblob += IN_BLOB['CLUSTER'] B.cpu_blob = np.zeros(len(B), np.float32) t1 = time.clock() B.cpu_blob[:] = t1 - t0 return B
class OneBlob(object): def __init__(self, name, blobwcs, blobmask, timargs, srcs, bands, plots, ps, simul_opt, use_ceres, hasbright, hasmedium, rex): self.name = name self.rex = rex self.blobwcs = blobwcs self.pixscale = self.blobwcs.pixel_scale() self.blobmask = blobmask self.srcs = srcs self.bands = bands self.plots = plots self.plots_per_source = plots self.plots_per_model = False # blob-1-data.png, etc self.plots_single = False self.ps = ps self.simul_opt = simul_opt self.use_ceres = use_ceres self.hasbright = hasbright self.hasmedium = hasmedium self.tims = self.create_tims(timargs) self.total_pix = sum([np.sum(t.getInvError() > 0) for t in self.tims]) self.plots2 = False alphas = [0.1, 0.3, 1.0] self.optargs = dict(priors=True, shared_params=False, alphas=alphas, print_progress=True) self.blobh,self.blobw = blobmask.shape self.bigblob = (self.blobw * self.blobh) > 100*100 if self.bigblob: print('Big blob:', name) self.trargs = dict() # if use_ceres: # from tractor.ceres_optimizer import CeresOptimizer # ceres_optimizer = CeresOptimizer() # self.optargs.update(scale_columns=False, # scaled=False, # dynamic_scale=False) # self.trargs.update(optimizer=ceres_optimizer) # else: # self.optargs.update(dchisq = 0.1) from legacypipe.constrained_optimizer import ConstrainedOptimizer self.trargs.update(optimizer=ConstrainedOptimizer()) self.optargs.update(dchisq = 0.1) def run(self, B): # Not quite so many plots... self.plots1 = self.plots cat = Catalog(*self.srcs) tlast = Time() if self.plots: self._initial_plots() if not self.bigblob: print('Fitting just fluxes using initial models...') self._fit_fluxes(cat, self.tims, self.bands) tr = self.tractor(self.tims, cat) if self.plots: self._plots(tr, 'Initial models') # Optimize individual sources, in order of flux. # First, choose the ordering... Ibright = _argsort_by_brightness(cat, self.bands) if len(cat) > 1: self._optimize_individual_sources_subtract( cat, Ibright, B.cpu_source) else: self._optimize_individual_sources(tr, cat, Ibright, B.cpu_source) # Optimize all at once? if len(cat) > 1 and len(cat) <= 10: #tfit = Time() cat.thawAllParams() tr.optimize_loop(**self.optargs) if self.plots: self._plots(tr, 'After source fitting') plt.clf() self._plot_coadd(self.tims, self.blobwcs, model=tr) plt.title('After source fitting') self.ps.savefig() if self.plots_single: plt.figure(2) mods = list(tr.getModelImages()) coimgs,cons = quick_coadds(self.tims, self.bands, self.blobwcs, images=mods, fill_holes=False) dimshow(get_rgb(coimgs,self.bands), ticks=False) plt.savefig('blob-%s-initmodel.png' % (self.name)) res = [(tim.getImage() - mod) for tim,mod in zip(self.tims, mods)] coresids,nil = quick_coadds(self.tims, self.bands, self.blobwcs, images=res) dimshow(get_rgb(coresids, self.bands, **rgbkwargs_resid), ticks=False) plt.savefig('blob-%s-initresid.png' % (self.name)) dimshow(get_rgb(coresids, self.bands), ticks=False) plt.savefig('blob-%s-initsub.png' % (self.name)) plt.figure(1) print('Blob', self.name, 'finished initial fitting:', Time()-tlast) tlast = Time() # Next, model selections: point source vs dev/exp vs composite. self.run_model_selection(cat, Ibright, B) print('Blob', self.name, 'finished model selection:', Time()-tlast) tlast = Time() if self.plots: self._plots(tr, 'After model selection') if self.plots_single: plt.figure(2) mods = list(tr.getModelImages()) coimgs,cons = quick_coadds(self.tims, self.bands, self.blobwcs, images=mods, fill_holes=False) dimshow(get_rgb(coimgs,self.bands), ticks=False) plt.savefig('blob-%s-model.png' % (self.name)) res = [(tim.getImage() - mod) for tim,mod in zip(self.tims, mods)] coresids,nil = quick_coadds(self.tims, self.bands, self.blobwcs, images=res) dimshow(get_rgb(coresids, self.bands, **rgbkwargs_resid), ticks=False) plt.savefig('blob-%s-resid.png' % (self.name)) plt.figure(1) # Cut down to just the kept sources I = np.array([i for i,s in enumerate(cat) if s is not None]) B.cut(I) cat = Catalog(*B.sources) tr.catalog = cat # Do another quick round of flux-only fitting? # This does horribly -- fluffy galaxies go out of control because # they're only constrained by pixels within this blob. #_fit_fluxes(cat, tims, bands, use_ceres, alphas) # ### Simultaneous re-opt? # if simul_opt and len(cat) > 1 and len(cat) <= 10: # #tfit = Time() # cat.thawAllParams() # #print('Optimizing:', tr) # #tr.printThawedParams() # max_cpu = 300. # cpu0 = time.clock() # for step in range(50): # dlnp,X,alpha = tr.optimize(**optargs) # cpu = time.clock() # if cpu-cpu0 > max_cpu: # print('Warning: Exceeded maximum CPU time for source') # break # if dlnp < 0.1: # break # #print('Simultaneous fit took:', Time()-tfit) # Compute variances on all parameters for the kept model B.srcinvvars = [None for i in range(len(B))] cat.thawAllRecursive() cat.freezeAllParams() for isub in range(len(B.sources)): cat.thawParam(isub) src = cat[isub] if src is None: cat.freezeParam(isub) continue # Convert to "vanilla" ellipse parameterization nsrcparams = src.numberOfParams() _convert_ellipses(src) assert(src.numberOfParams() == nsrcparams) # print('Computing variances for source', src, ': N params:', nsrcparams) # print('Source params:') # src.printThawedParams() # For Gaia sources, temporarily convert the GaiaPosition to a # RaDecPos in order to compute the invvar it would have in our # imaging? Or just plug in the Gaia-measured uncertainties?? # (going to implement the latter) # Compute inverse-variances allderivs = tr.getDerivs() ivars = _compute_invvars(allderivs) assert(len(ivars) == nsrcparams) #print('Inverse-variances:', ivars) B.srcinvvars[isub] = ivars assert(len(B.srcinvvars[isub]) == cat[isub].numberOfParams()) cat.freezeParam(isub) # Check for sources with zero inverse-variance -- I think these # can be generated during the "Simultaneous re-opt" stage above -- # sources can get scattered outside the blob. I, = np.nonzero([np.sum(iv) > 0 for iv in B.srcinvvars]) if len(I) < len(B): print('Keeping', len(I), 'of', len(B),'sources with non-zero ivar') B.cut(I) cat = Catalog(*B.sources) tr.catalog = cat M = _compute_source_metrics(B.sources, self.tims, self.bands, tr) for k,v in M.items(): B.set(k, v) print('Blob', self.name, 'finished:', Time()-tlast) def run_model_selection(self, cat, Ibright, B): # We compute & subtract initial models for the other sources while # fitting each source: # -Remember the original images # -Compute initial models for each source (in each tim) # -Subtract initial models from images # -During fitting, for each source: # -add back in the source's initial model (to each tim) # -fit, with Catalog([src]) # -subtract final model (from each tim) # -Replace original images models = SourceModels() # Remember original tim images models.save_images(self.tims) # Create initial models for each tim x each source models.create(self.tims, cat, subtract=True) N = len(cat) B.dchisq = np.zeros((N, 5), np.float32) B.all_models = np.array([{} for i in range(N)]) B.all_model_ivs = np.array([{} for i in range(N)]) B.all_model_cpu = np.array([{} for i in range(N)]) B.all_model_hit_limit = np.array([{} for i in range(N)]) # Model selection for sources, in decreasing order of brightness for numi,srci in enumerate(Ibright): src = cat[srci] print('Model selection for source %i of %i in blob %s; sourcei %i' % (numi+1, len(Ibright), self.name, srci)) cpu0 = time.clock() # Add this source's initial model back in. models.add(srci, self.tims) if self.plots_single: plt.figure(2) tr = self.tractor(self.tims, cat) coimgs,cons = quick_coadds(self.tims, self.bands, self.blobwcs, fill_holes=False) rgb = get_rgb(coimgs,self.bands) plt.imsave('blob-%s-%s-bdata.png' % (self.name, srci), rgb, origin='lower') plt.figure(1) keepsrc = self.model_selection_one_source(src, srci, models, B) B.sources[srci] = keepsrc cat[srci] = keepsrc # Re-remove the final fit model for this source. models.update_and_subtract(srci, keepsrc, self.tims) if self.plots_single: plt.figure(2) tr = self.tractor(self.tims, cat) coimgs,cons = quick_coadds(self.tims, self.bands, self.blobwcs, fill_holes=False) dimshow(get_rgb(coimgs,self.bands), ticks=False) plt.savefig('blob-%s-%i-sub.png' % (self.name, srci)) plt.figure(1) cpu1 = time.clock() B.cpu_source[srci] += (cpu1 - cpu0) models.restore_images(self.tims) del models def model_selection_one_source(self, src, srci, models, B): # Fit local constant sky background levels if we're in the # same blob as a medium-brightness star. fit_background = self.hasmedium if self.bigblob: mods = [mod[srci] for mod in models.models] srctims,modelMasks = _get_subimages(self.tims, mods, src) # Create a little local WCS subregion for this source, by # resampling non-zero inverrs from the srctims into blobwcs insrc = np.zeros((self.blobh,self.blobw), bool) for tim in srctims: try: Yo,Xo,Yi,Xi,nil = resample_with_wcs( self.blobwcs, tim.subwcs, [],2) except: continue insrc[Yo,Xo] |= (tim.inverr[Yi,Xi] > 0) if np.sum(insrc) == 0: # No source pixels touching blob... this can # happen when a source scatters outside the blob # in the fitting stage. Drop the source here. return None yin = np.max(insrc, axis=1) xin = np.max(insrc, axis=0) yl,yh = np.flatnonzero(yin)[np.array([0,-1])] xl,xh = np.flatnonzero(xin)[np.array([0,-1])] del insrc srcwcs = self.blobwcs.get_subimage(xl, yl, 1+xh-xl, 1+yh-yl) srcwcs_x0y0 = (xl, yl) # A mask for which pixels in the 'srcwcs' square are occupied. srcblobmask = self.blobmask[yl:yh+1, xl:xh+1] else: modelMasks = models.model_masks(srci, src) srctims = self.tims srcwcs = self.blobwcs srcwcs_x0y0 = (0, 0) srcblobmask = self.blobmask if self.plots_per_source: # This is a handy blob-coordinates plot of the data # going into the fit. plt.clf() nil,nil,coimgs,nil = quick_coadds(srctims, self.bands,self.blobwcs, fill_holes=False, get_cow=True) dimshow(get_rgb(coimgs, self.bands)) ax = plt.axis() pos = src.getPosition() ok,x,y = self.blobwcs.radec2pixelxy(pos.ra, pos.dec) ix,iy = int(np.round(x-1)), int(np.round(y-1)) plt.plot(x-1, y-1, 'r+') plt.axis(ax) plt.title('Model selection: stage1 data') self.ps.savefig() # Mask out other sources while fitting this one, by # finding symmetrized blobs of significant pixels mask_others = True if mask_others: from legacypipe.detection import detection_maps from astrometry.util.multiproc import multiproc from scipy.ndimage.morphology import binary_dilation from scipy.ndimage.measurements import label, find_objects # Compute per-band detection maps mp = multiproc() detmaps,detivs,satmaps = detection_maps( srctims, srcwcs, self.bands, mp) # Compute the symmetric area that fits in this 'tim' pos = src.getPosition() ok,xx,yy = srcwcs.radec2pixelxy(pos.ra, pos.dec) bh,bw = srcblobmask.shape ix = int(np.clip(np.round(xx-1), 0, bw-1)) iy = int(np.clip(np.round(yy-1), 0, bh-1)) flipw = min(ix, bw-1-ix) fliph = min(iy, bh-1-iy) flipblobs = np.zeros(srcblobmask.shape, bool) # Go through the per-band detection maps, marking significant pixels for i,(detmap,detiv) in enumerate(zip(detmaps,detivs)): sn = detmap * np.sqrt(detiv) slc = (slice(iy-fliph, iy+fliph+1), slice(ix-flipw, ix+flipw+1)) flipsn = np.zeros_like(sn) # Symmetrize flipsn[slc] = np.minimum(sn[slc], np.flipud(np.fliplr(sn[slc]))) # just OR the detection maps per-band... flipblobs |= (flipsn > 5.) blobs,nb = label(flipblobs) goodblob = blobs[iy,ix] if goodblob != 0: flipblobs = (blobs == goodblob) dilated = binary_dilation(flipblobs, iterations=4) if not np.any(dilated): print('No pixels in dilated symmetric mask') return None yin = np.max(dilated, axis=1) xin = np.max(dilated, axis=0) yl,yh = np.flatnonzero(yin)[np.array([0,-1])] xl,xh = np.flatnonzero(xin)[np.array([0,-1])] #print('Dilated: good bounds x', xl,xh, 'y', yl,yh) #oldshape = srcwcs.shape (oldx0,oldy0) = srcwcs_x0y0 srcwcs = srcwcs.get_subimage(xl, yl, 1+xh-xl, 1+yh-yl) srcwcs_x0y0 = (oldx0 + xl, oldy0 + yl) srcblobmask = srcblobmask[yl:yh+1, xl:xh+1] #print('Cut srcwcs from', oldshape, 'to', srcwcs.shape) dilated = dilated[yl:yh+1, xl:xh+1] flipblobs = flipblobs[yl:yh+1, xl:xh+1] saved_srctim_ies = [] keep_srctims = [] mm = [] totalpix = 0 for tim in srctims: # Zero out inverse-errors for all pixels outside # 'dilated'. try: Yo,Xo,Yi,Xi,nil = resample_with_wcs( tim.subwcs, srcwcs, [], 2) except: continue ie = tim.getInvError() newie = np.zeros_like(ie) good, = np.nonzero(dilated[Yi,Xi] * (ie[Yo,Xo] > 0)) if len(good) == 0: print('Tim has inverr all == 0') continue yy = Yo[good] xx = Xo[good] newie[yy,xx] = ie[yy,xx] xl,xh = xx.min(), xx.max() yl,yh = yy.min(), yy.max() totalpix += len(xx) d = { src: ModelMask(xl, yl, 1+xh-xl, 1+yh-yl) } mm.append(d) saved_srctim_ies.append(ie) tim.inverr = newie keep_srctims.append(tim) srctims = keep_srctims modelMasks = mm B.blob_symm_nimages[srci] = len(srctims) B.blob_symm_npix[srci] = totalpix sh,sw = srcwcs.shape B.blob_symm_width [srci] = sw B.blob_symm_height[srci] = sh if self.plots_per_source: from legacypipe.detection import plot_boundary_map plt.clf() dimshow(get_rgb(coimgs, self.bands)) ax = plt.axis() plt.plot(x-1, y-1, 'r+') plt.axis(ax) sx0,sy0 = srcwcs_x0y0 sh,sw = srcwcs.shape ext = [sx0, sx0+sw, sy0, sy0+sh] plot_boundary_map(flipblobs, rgb=(255,255,255), extent=ext) plot_boundary_map(dilated, rgb=(0,255,0), extent=ext) plt.title('symmetrized blobs') self.ps.savefig() nil,nil,coimgs,nil = quick_coadds( srctims, self.bands, self.blobwcs, fill_holes=False, get_cow=True) # dimshow(get_rgb(coimgs, self.bands)) # ax = plt.axis() # plt.plot(x-1, y-1, 'r+') # plt.axis(ax) # plt.title('Symmetric-blob masked') # self.ps.savefig() # plt.clf() # for tim in srctims: # ie = tim.getInvError() # sigmas = (tim.getImage() * ie)[ie > 0] # plt.hist(sigmas, range=(-5,5), bins=21, histtype='step') # plt.axvline(np.mean(sigmas), alpha=0.5) # plt.axvline(0., color='k', lw=3, alpha=0.5) # plt.xlabel('Image pixels (sigma)') # plt.title('Symmetrized pixel values') # self.ps.savefig() # # plot the modelmasks for each tim. # plt.clf() # R = int(np.floor(np.sqrt(len(srctims)))) # C = int(np.ceil(len(srctims) / float(R))) # for i,tim in enumerate(srctims): # plt.subplot(R, C, i+1) # msk = modelMasks[i][src].mask # print('Mask:', msk) # if msk is None: # continue # plt.imshow(msk, interpolation='nearest', origin='lower', vmin=0, vmax=1) # plt.title(tim.name) # plt.suptitle('Model Masks') # self.ps.savefig() if self.bigblob and self.plots_per_source: # This is a local source-WCS plot of the data going into the # fit. plt.clf() coimgs,cons = quick_coadds(srctims, self.bands, srcwcs, fill_holes=False) dimshow(get_rgb(coimgs, self.bands)) plt.title('Model selection: stage1 data (srcwcs)') self.ps.savefig() #self._plots(srctractor, 'Model selection init') srctractor = self.tractor(srctims, [src]) srctractor.setModelMasks(modelMasks) srccat = srctractor.getCatalog() ok,ix,iy = srcwcs.radec2pixelxy(src.getPosition().ra, src.getPosition().dec) ix = int(ix-1) iy = int(iy-1) # Start in blob sh,sw = srcwcs.shape if ix < 0 or iy < 0 or ix >= sw or iy >= sh or not srcblobmask[iy,ix]: print('Source is starting outside blob -- skipping.') return None if fit_background: for tim in srctims: tim.freezeAllBut('sky') srctractor.thawParam('images') skyparams = srctractor.images.getParams() enable_galaxy_cache() # Compute the log-likehood without a source here. srccat[0] = None if fit_background: #print('Fitting no-source model (sky)') srctractor.optimize_loop(**self.optargs) #srctractor.images.printThawedParams() chisqs_none = _per_band_chisqs(srctractor, self.bands) nparams = dict(ptsrc=2, simple=2, rex=3, exp=5, dev=5, comp=9) # This is our "upgrade" threshold: how much better a galaxy # fit has to be versus ptsrc, and comp versus galaxy. galaxy_margin = 3.**2 + (nparams['exp'] - nparams['ptsrc']) # *chisqs* is actually chi-squared improvement vs no source; # larger is a better fit. chisqs = dict(none=0) oldmodel, ptsrc, simple, dev, exp, comp = _initialize_models( src, self.rex) if self.rex: simname = 'rex' rex = simple else: simname = 'simple' trymodels = [('ptsrc', ptsrc)] if oldmodel == 'ptsrc': forced = False if isinstance(src, GaiaSource): print('Gaia source', src) if src.isForcedPointSource(): forced = True if forced: print('Gaia source is forced to be a point source -- not trying other models') elif self.hasbright: print('Not computing galaxy models: bright star in blob') else: trymodels.append((simname, simple)) # Try galaxy models if simple > ptsrc, or if bright. # The 'gals' model is just a marker trymodels.append(('gals', None)) else: trymodels.extend([('dev', dev), ('exp', exp), ('comp', comp)]) cputimes = {} for name,newsrc in trymodels: cpum0 = time.clock() if name == 'gals': # If 'simple' was better than 'ptsrc', or the source is # bright, try the galaxy models. chi_sim = chisqs.get(simname, 0) chi_psf = chisqs.get('ptsrc', 0) if chi_sim > chi_psf or max(chi_psf, chi_sim) > 400: trymodels.extend([ ('dev', dev), ('exp', exp), ('comp', comp)]) continue if name == 'comp' and newsrc is None: # Compute the comp model if exp or dev would be accepted smod = _select_model(chisqs, nparams, galaxy_margin, self.rex) if smod not in ['dev', 'exp']: continue newsrc = comp = FixedCompositeGalaxy( src.getPosition(), src.getBrightness(), SoftenedFracDev(0.5), exp.getShape(), dev.getShape()).copy() srccat[0] = newsrc #print('Starting optimization for', name) # Set maximum galaxy model sizes # FIXME -- could use different fractions for deV vs exp (or comp) fblob = 0.8 sh,sw = srcwcs.shape rmax = np.log(fblob * max(sh, sw) * self.pixscale) if name in ['exp', 'rex', 'dev']: newsrc.shape.setMaxLogRadius(rmax) elif name in ['comp']: newsrc.shapeExp.setMaxLogRadius(rmax) newsrc.shapeDev.setMaxLogRadius(rmax) ### FIXME -- also set model rendering limits here?? # Use the same modelMask shapes as the original source ('src'). # Need to create newsrc->mask mappings though: mm = remap_modelmask(modelMasks, src, newsrc) srctractor.setModelMasks(mm) enable_galaxy_cache() # Save these modelMasks for later... newsrc_mm = mm #lnp = srctractor.getLogProb() #print('Initial log-prob:', lnp) #print('vs original src: ', lnp - lnp0) # if self.plots and False: # # Grid of derivatives. # _plot_derivs(tims, newsrc, srctractor, ps) # if self.plots: # mods = list(srctractor.getModelImages()) # plt.clf() # coimgs,cons = quick_coadds(srctims, bands, srcwcs, # images=mods, fill_holes=False) # dimshow(get_rgb(coimgs, bands)) # plt.title('Initial: ' + name) # self.ps.savefig() if fit_background: #print('Resetting sky params.') srctractor.images.setParams(skyparams) srctractor.thawParam('images') # First-round optimization (during model selection) #print('Optimizing: first round for', name, ':', len(srctims)) #print(newsrc) cpustep0 = time.clock() R = srctractor.optimize_loop(**self.optargs) #print('Optimizing first round', name, 'took', # time.clock()-cpustep0) print('Fit result:', newsrc) hit_limit = R.get('hit_limit', False) if hit_limit: if name in ['exp', 'rex', 'dev']: print('Hit limit: r %.2f vs %.2f' % (newsrc.shape.re, np.exp(rmax))) elif name in ['comp']: print('Hit limit: r %.2f, %.2f vs %.2f' % (newsrc.shapeExp.re, newsrc.shapeDev.re, np.exp(rmax))) #srctractor.printThawedParams() ok,ix,iy = srcwcs.radec2pixelxy(newsrc.getPosition().ra, newsrc.getPosition().dec) ix = int(ix-1) iy = int(iy-1) sh,sw = srcblobmask.shape if ix < 0 or iy < 0 or ix >= sw or iy >= sh or not srcblobmask[iy,ix]: # Exited blob! print('Source exited sub-blob!') # FIXME -- do we want to save any of the fitting results? # Or flag this?? continue disable_galaxy_cache() # Compute inverse-variances for each source. # Convert to "vanilla" ellipse parameterization # (but save old shapes first) # we do this (rather than making a copy) because we want to # use the same modelMask maps. if isinstance(newsrc, (DevGalaxy, ExpGalaxy)): oldshape = newsrc.shape elif isinstance(newsrc, FixedCompositeGalaxy): oldshape = (newsrc.shapeExp, newsrc.shapeDev,newsrc.fracDev) if fit_background: # We have to freeze the sky here before computing # uncertainties srctractor.freezeParam('images') nsrcparams = newsrc.numberOfParams() _convert_ellipses(newsrc) assert(newsrc.numberOfParams() == nsrcparams) # Compute inverse-variances # This uses the second-round modelMasks. allderivs = srctractor.getDerivs() ivars = _compute_invvars(allderivs) assert(len(ivars) == nsrcparams) B.all_model_ivs[srci][name] = np.array(ivars).astype(np.float32) B.all_models[srci][name] = newsrc.copy() assert(B.all_models[srci][name].numberOfParams() == nsrcparams) # Now revert the ellipses! if isinstance(newsrc, (DevGalaxy, ExpGalaxy)): newsrc.shape = oldshape elif isinstance(newsrc, FixedCompositeGalaxy): (newsrc.shapeExp, newsrc.shapeDev,newsrc.fracDev) = oldshape # Use the original 'srctractor' here so that the different # models are evaluated on the same pixels. # ---> AND with the same modelMasks as the original source... # srctractor.setModelMasks(newsrc_mm) ch = _per_band_chisqs(srctractor, self.bands) chisqs[name] = _chisq_improvement(newsrc, ch, chisqs_none) cpum1 = time.clock() B.all_model_cpu[srci][name] = cpum1 - cpum0 cputimes[name] = cpum1 - cpum0 B.all_model_hit_limit[srci][name] = hit_limit if mask_others: for ie,tim in zip(saved_srctim_ies, srctims): tim.inverr = ie # After model selection, revert the sky # (srctims=tims when not bigblob) if fit_background: srctractor.images.setParams(skyparams) # Actually select which model to keep. This "modnames" # array determines the order of the elements in the DCHISQ # column of the catalog. modnames = ['ptsrc', simname, 'dev', 'exp', 'comp'] keepmod = _select_model(chisqs, nparams, galaxy_margin, self.rex) keepsrc = {'none':None, 'ptsrc':ptsrc, simname:simple, 'dev':dev, 'exp':exp, 'comp':comp}[keepmod] bestchi = chisqs.get(keepmod, 0.) B.dchisq[srci, :] = np.array([chisqs.get(k,0) for k in modnames]) if keepsrc is not None and bestchi == 0.: # Weird edge case, or where some best-fit fluxes go # negative. eg # https://github.com/legacysurvey/legacypipe/issues/174 print('Best dchisq is 0 -- dropping source') keepsrc = None B.hit_limit[srci] = B.all_model_hit_limit[srci].get(keepmod, False) # This is the model-selection plot if self.plots_per_source: from collections import OrderedDict subplots = [] plt.clf() rows,cols = 3, 6 mods = OrderedDict([ ('none',None), ('ptsrc',ptsrc), (simname,simple), ('dev',dev), ('exp',exp), ('comp',comp)]) for imod,modname in enumerate(mods.keys()): if modname != 'none' and not modname in chisqs: continue srccat[0] = mods[modname] srctractor.setModelMasks(None) axes = [] plt.subplot(rows, cols, imod+1) if modname == 'none': # In the first panel, we show a coadd of the data coimgs, cons = quick_coadds(srctims, self.bands,srcwcs) rgbims = coimgs rgb = get_rgb(coimgs, self.bands) dimshow(rgb, ticks=False) subplots.append(('data', rgb)) axes.append(plt.gca()) ax = plt.axis() ok,x,y = srcwcs.radec2pixelxy( src.getPosition().ra, src.getPosition().dec) plt.plot(x-1, y-1, 'r+') plt.axis(ax) tt = 'Image' chis = [((tim.getImage()) * tim.getInvError())**2 for tim in srctims] res = [tim.getImage() for tim in srctims] else: modimgs = list(srctractor.getModelImages()) comods,nil = quick_coadds(srctims, self.bands, srcwcs, images=modimgs) rgbims = comods rgb = get_rgb(comods, self.bands) dimshow(rgb, ticks=False) axes.append(plt.gca()) subplots.append(('mod'+modname, rgb)) tt = modname #+ '\n(%.0f s)' % cputimes[modname] chis = [((tim.getImage() - mod) * tim.getInvError())**2 for tim,mod in zip(srctims, modimgs)] res = [(tim.getImage() - mod) for tim,mod in zip(srctims, modimgs)] # Second row: same rgb image with arcsinh stretch plt.subplot(rows, cols, imod+1+cols) dimshow(get_rgb(rgbims, self.bands, **rgbkwargs), ticks=False) axes.append(plt.gca()) plt.title(tt) # residuals coresids,nil = quick_coadds(srctims, self.bands, srcwcs, images=res) plt.subplot(rows, cols, imod+1+2*cols) rgb = get_rgb(coresids, self.bands, **rgbkwargs_resid) dimshow(rgb, ticks=False) axes.append(plt.gca()) subplots.append(('res'+modname, rgb)) plt.title('chisq %.0f' % chisqs[modname], fontsize=8) # Highlight the model to be kept if modname == keepmod: for ax in axes: for spine in ax.spines.values(): spine.set_edgecolor('red') spine.set_linewidth(2) plt.suptitle('Blob %s, source %i: keeping %s\nwas: %s' % (self.name, srci, keepmod, str(src)), fontsize=10) self.ps.savefig() if self.plots_single: for name,rgb in subplots: plt.figure(2) plt.subplots_adjust(left=0.01, right=0.99, bottom=0.01, top=0.99) dimshow(rgb, ticks=False) fn = 'blob-%s-%i-%s.png' % (self.name, srci, name) plt.savefig(fn) print('Wrote', fn) plt.figure(1) return keepsrc def _optimize_individual_sources(self, tr, cat, Ibright, cputime): # Single source (though this is coded to handle multiple sources) # Fit sources one at a time, but don't subtract other models cat.freezeAllParams() models = SourceModels() models.create(self.tims, cat) enable_galaxy_cache() for numi,i in enumerate(Ibright): cpu0 = time.clock() #print('Fitting source', i, '(%i of %i in blob)' % # (numi, len(Ibright))) cat.freezeAllBut(i) modelMasks = models.model_masks(0, cat[i]) tr.setModelMasks(modelMasks) tr.optimize_loop(**self.optargs) #print('Fitting source took', Time()-tsrc) # print(cat[i]) cpu1 = time.clock() cputime[i] += (cpu1 - cpu0) tr.setModelMasks(None) disable_galaxy_cache() def tractor(self, tims, cat): tr = Tractor(tims, cat, **self.trargs) tr.freezeParams('images') return tr def _optimize_individual_sources_subtract(self, cat, Ibright, cputime): # -Remember the original images # -Compute initial models for each source (in each tim) # -Subtract initial models from images # -During fitting, for each source: # -add back in the source's initial model (to each tim) # -fit, with Catalog([src]) # -subtract final model (from each tim) # -Replace original images models = SourceModels() # Remember original tim images models.save_images(self.tims) # Create & subtract initial models for each tim x each source models.create(self.tims, cat, subtract=True) # For sources, in decreasing order of brightness for numi,srci in enumerate(Ibright): cpu0 = time.clock() print('Fitting source', srci, '(%i of %i in blob %s)' % (numi+1, len(Ibright), self.name)) src = cat[srci] # Add this source's initial model back in. models.add(srci, self.tims) if self.bigblob: # Create super-local sub-sub-tims around this source # Make the subimages the same size as the modelMasks. #tbb0 = Time() mods = [mod[srci] for mod in models.models] srctims,modelMasks = _get_subimages(self.tims, mods, src) #print('Creating srctims:', Time()-tbb0) # We plots only the first & last three sources if self.plots_per_source and (numi < 3 or numi >= len(Ibright)-3): plt.clf() # Recompute coadds because of the subtract-all-and-readd shuffle coimgs,cons = quick_coadds(self.tims, self.bands, self.blobwcs, fill_holes=False) rgb = get_rgb(coimgs, self.bands) dimshow(rgb) #dimshow(self.rgb) ax = plt.axis() for tim in srctims: h,w = tim.shape tx,ty = [0,0,w,w,0], [0,h,h,0,0] rd = [tim.getWcs().pixelToPosition(xi,yi) for xi,yi in zip(tx,ty)] ra = [p.ra for p in rd] dec = [p.dec for p in rd] ok,x,y = self.blobwcs.radec2pixelxy(ra, dec) plt.plot(x, y, 'b-') ra,dec = tim.subwcs.pixelxy2radec(tx, ty) ok,x,y = self.blobwcs.radec2pixelxy(ra, dec) plt.plot(x, y, 'c-') plt.title('source %i of %i' % (numi, len(Ibright))) plt.axis(ax) self.ps.savefig() else: srctims = self.tims modelMasks = models.model_masks(srci, src) srctractor = self.tractor(srctims, [src]) #print('Setting modelMasks:', modelMasks) srctractor.setModelMasks(modelMasks) # if plots and False: # spmods,spnames = [],[] # spallmods,spallnames = [],[] # if numi == 0: # spallmods.append(list(tr.getModelImages())) # spallnames.append('Initial (all)') # spmods.append(list(srctractor.getModelImages())) # spnames.append('Initial') # First-round optimization #print('First-round initial log-prob:', srctractor.getLogProb()) srctractor.optimize_loop(**self.optargs) #print('First-round final log-prob:', srctractor.getLogProb()) # if plots and False: # spmods.append(list(srctractor.getModelImages())) # spnames.append('Fit') # spallmods.append(list(tr.getModelImages())) # spallnames.append('Fit (all)') # # if plots and False: # plt.figure(1, figsize=(8,6)) # plt.subplots_adjust(left=0.01, right=0.99, top=0.95, # bottom=0.01, hspace=0.1, wspace=0.05) # #plt.figure(2, figsize=(3,3)) # #plt.subplots_adjust(left=0.005, right=0.995, # # top=0.995,bottom=0.005) # #_plot_mods(tims, spmods, spnames, bands, None, None, bslc, # # blobw, blobh, ps, chi_plots=plots2) # plt.figure(2, figsize=(3,3.5)) # plt.subplots_adjust(left=0.005, right=0.995, # top=0.88, bottom=0.005) # plt.suptitle('Blob %i' % iblob) # tempims = [tim.getImage() for tim in tims] # # _plot_mods(list(srctractor.getImages()), spmods, spnames, # bands, None, None, bslc, blobw, blobh, ps, # chi_plots=plots2, rgb_plots=True, main_plot=False, # rgb_format=('spmods Blob %i, src %i: %%s' % # (iblob, i))) # _plot_mods(tims, spallmods, spallnames, bands, None, None, # bslc, blobw, blobh, ps, # chi_plots=plots2, rgb_plots=True, main_plot=False, # rgb_format=('spallmods Blob %i, src %i: %%s' % # (iblob, i))) # # models.restore_images(tims) # _plot_mods(tims, spallmods, spallnames, bands, None, None, # bslc, blobw, blobh, ps, # chi_plots=plots2, rgb_plots=True, main_plot=False, # rgb_format='Blob %i, src %i: %%s' % (iblob, i)) # for tim,im in zip(tims, tempims): # tim.data = im # Re-remove the final fit model for this source models.update_and_subtract(srci, src, self.tims) srctractor.setModelMasks(None) disable_galaxy_cache() #print('Fitting source took', Time()-tsrc) #print(src) cpu1 = time.clock() cputime[srci] += (cpu1 - cpu0) models.restore_images(self.tims) del models def _fit_fluxes(self, cat, tims, bands): cat.thawAllRecursive() for src in cat: src.freezeAllBut('brightness') for b in bands: for src in cat: src.getBrightness().freezeAllBut(b) # Images for this band btims = [tim for tim in tims if tim.band == b] btr = self.tractor(btims, cat) btr.optimize_forced_photometry(shared_params=False, wantims=False) cat.thawAllRecursive() def _plots(self, tr, title): plotmods = [] plotmodnames = [] plotmods.append(list(tr.getModelImages())) plotmodnames.append(title) for tim in tr.images: if hasattr(tim, 'resamp'): del tim.resamp _plot_mods(tr.images, plotmods, self.blobwcs, plotmodnames, self.bands, None, None, None, self.blobw, self.blobh, self.ps, chi_plots=False) for tim in tr.images: if hasattr(tim, 'resamp'): del tim.resamp def _plot_coadd(self, tims, wcs, model=None, resid=None): if resid is not None: mods = list(resid.getChiImages()) coimgs,cons = quick_coadds(tims, self.bands, wcs, images=mods, fill_holes=False) dimshow(get_rgb(coimgs,self.bands, **rgbkwargs_resid)) return mods = None if model is not None: mods = list(model.getModelImages()) coimgs,cons = quick_coadds(tims, self.bands, wcs, images=mods, fill_holes=False) dimshow(get_rgb(coimgs,self.bands)) def _initial_plots(self): print('Plotting blob image for blob', self.name) coimgs,cons = quick_coadds(self.tims, self.bands, self.blobwcs, fill_holes=False) self.rgb = get_rgb(coimgs, self.bands) plt.clf() dimshow(self.rgb) plt.title('Blob: %s' % self.name) self.ps.savefig() if self.plots_single: plt.figure(2) dimshow(self.rgb, ticks=False) plt.savefig('blob-%s-data.png' % (self.name)) plt.figure(1) ok,x0,y0 = self.blobwcs.radec2pixelxy( np.array([src.getPosition().ra for src in self.srcs]), np.array([src.getPosition().dec for src in self.srcs])) ax = plt.axis() plt.plot(x0-1, y0-1, 'r.') plt.axis(ax) plt.title('initial sources') self.ps.savefig() # plt.clf() # ccmap = dict(g='g', r='r', z='m') # for tim in tims: # chi = (tim.data * tim.inverr)[tim.inverr > 0] # plt.hist(chi.ravel(), range=(-5,10), bins=100, histtype='step', # color=ccmap[tim.band]) # plt.xlabel('signal/noise per pixel') # self.ps.savefig() def create_tims(self, timargs): # In order to make multiprocessing easier, the one_blob method # is passed all the ingredients to make local tractor Images # rather than the Images themselves. Here we build the # 'tims'. tims = [] for (img, inverr, twcs, wcs, pcal, sky, psf, name, sx0, sx1, sy0, sy1, band, sig1, modelMinval, imobj) in timargs: # Mask out inverr for pixels that are not within the blob. subwcs = wcs.get_subimage(int(sx0), int(sy0), int(sx1-sx0), int(sy1-sy0)) try: Yo,Xo,Yi,Xi,rims = resample_with_wcs(subwcs, self.blobwcs, [], 2) except OverlapError: continue if len(Yo) == 0: continue inverr2 = np.zeros_like(inverr) I = np.flatnonzero(self.blobmask[Yi,Xi]) inverr2[Yo[I],Xo[I]] = inverr[Yo[I],Xo[I]] inverr = inverr2 # If the subimage (blob) is small enough, instantiate a # constant PSF model in the center. if sy1-sy0 < 400 and sx1-sx0 < 400: subpsf = psf.constantPsfAt((sx0+sx1)/2., (sy0+sy1)/2.) else: # Otherwise, instantiate a (shifted) spatially-varying # PsfEx model. subpsf = psf.getShifted(sx0, sy0) tim = Image(data=img, inverr=inverr, wcs=twcs, psf=subpsf, photocal=pcal, sky=sky, name=name) tim.band = band tim.sig1 = sig1 tim.modelMinval = modelMinval tim.subwcs = subwcs tim.meta = imobj tim.psf_sigma = imobj.fwhm / 2.35 tim.dq = None tims.append(tim) return tims def _convert_ellipses(src): if isinstance(src, (DevGalaxy, ExpGalaxy)): #print('Converting ellipse for source', src) src.shape = src.shape.toEllipseE() #print('--->', src.shape) if isinstance(src, RexGalaxy): src.shape.freezeParams('e1', 'e2') elif isinstance(src, FixedCompositeGalaxy): src.shapeExp = src.shapeExp.toEllipseE() src.shapeDev = src.shapeDev.toEllipseE() src.fracDev = FracDev(src.fracDev.clipped()) def _compute_invvars(allderivs): ivs = [] for iparam,derivs in enumerate(allderivs): chisq = 0 for deriv,tim in derivs: h,w = tim.shape deriv.clipTo(w,h) ie = tim.getInvError() slc = deriv.getSlice(ie) chi = deriv.patch * ie[slc] chisq += (chi**2).sum() ivs.append(chisq) return ivs def _argsort_by_brightness(cat, bands): fluxes = [] for src in cat: # HACK -- here we just *sum* the nanomaggies in each band. Bogus! br = src.getBrightness() flux = sum([br.getFlux(band) for band in bands]) fluxes.append(flux) Ibright = np.argsort(-np.array(fluxes)) return Ibright def _compute_source_metrics(srcs, tims, bands, tr): # rchi2 quality-of-fit metric rchi2_num = np.zeros((len(srcs),len(bands)), np.float32) rchi2_den = np.zeros((len(srcs),len(bands)), np.float32) # fracflux degree-of-blending metric fracflux_num = np.zeros((len(srcs),len(bands)), np.float32) fracflux_den = np.zeros((len(srcs),len(bands)), np.float32) # fracin flux-inside-blob metric fracin_num = np.zeros((len(srcs),len(bands)), np.float32) fracin_den = np.zeros((len(srcs),len(bands)), np.float32) # fracmasked: fraction of masked pixels metric fracmasked_num = np.zeros((len(srcs),len(bands)), np.float32) fracmasked_den = np.zeros((len(srcs),len(bands)), np.float32) for iband,band in enumerate(bands): for tim in tims: if tim.band != band: continue mod = np.zeros(tim.getModelShape(), tr.modtype) srcmods = [None for src in srcs] counts = np.zeros(len(srcs)) pcal = tim.getPhotoCal() # For each source, compute its model and record its flux # in this image. Also compute the full model *mod*. for isrc,src in enumerate(srcs): patch = tr.getModelPatch(tim, src, minsb=tim.modelMinval) if patch is None or patch.patch is None: continue counts[isrc] = np.sum([np.abs(pcal.brightnessToCounts(b)) for b in src.getBrightnesses()]) if counts[isrc] == 0: continue H,W = mod.shape patch.clipTo(W,H) srcmods[isrc] = patch patch.addTo(mod) # Now compute metrics for each source for isrc,patch in enumerate(srcmods): if patch is None: continue if patch.patch is None: continue if counts[isrc] == 0: continue if np.sum(patch.patch**2) == 0: continue slc = patch.getSlice(mod) patch = patch.patch # print('fracflux: band', band, 'isrc', isrc, 'tim', tim.name) # print('src:', srcs[isrc]) # print('patch sum', np.sum(patch),'abs',np.sum(np.abs(patch))) # print('counts:', counts[isrc]) # print('mod slice sum', np.sum(mod[slc])) # print('mod[slc] - patch:', np.sum(mod[slc] - patch)) # (mod - patch) is flux from others # (mod - patch) / counts is normalized flux from others # We take that and weight it by this source's profile; # patch / counts is unit profile # But this takes the dot product between the profiles, # so we have to normalize appropriately, ie by # (patch**2)/counts**2; counts**2 drops out of the # denom. If you have an identical source with twice the flux, # this results in fracflux being 2.0 # fraction of this source's flux that is inside this patch. # This can be < 1 when the source is near an edge, or if the # source is a huge diffuse galaxy in a small patch. fin = np.abs(np.sum(patch) / counts[isrc]) # print('fin:', fin) # print('fracflux_num: fin *', # np.sum((mod[slc] - patch) * np.abs(patch)) / # np.sum(patch**2)) fracflux_num[isrc,iband] += (fin * np.sum((mod[slc] - patch) * np.abs(patch)) / np.sum(patch**2)) fracflux_den[isrc,iband] += fin fracmasked_num[isrc,iband] += ( np.sum((tim.getInvError()[slc] == 0) * np.abs(patch)) / np.abs(counts[isrc])) fracmasked_den[isrc,iband] += fin fracin_num[isrc,iband] += np.abs(np.sum(patch)) fracin_den[isrc,iband] += np.abs(counts[isrc]) tim.getSky().addTo(mod) chisq = ((tim.getImage() - mod) * tim.getInvError())**2 for isrc,patch in enumerate(srcmods): if patch is None or patch.patch is None: continue if counts[isrc] == 0: continue slc = patch.getSlice(mod) # We compute numerator and denom separately to handle # edge objects, where sum(patch.patch) < counts. # Also, to normalize by the number of images. (Being # on the edge of an image is like being in half an # image.) rchi2_num[isrc,iband] += (np.sum(chisq[slc] * patch.patch) / counts[isrc]) # If the source is not near an image edge, # sum(patch.patch) == counts[isrc]. rchi2_den[isrc,iband] += np.sum(patch.patch) / counts[isrc] #print('Fracflux_num:', fracflux_num) #print('Fracflux_den:', fracflux_den) fracflux = fracflux_num / fracflux_den rchi2 = rchi2_num / rchi2_den fracmasked = fracmasked_num / fracmasked_den # Eliminate NaNs (these happen when, eg, we have no coverage in one band but # sources detected in another band, hence denominator is zero) fracflux [ fracflux_den == 0] = 0. rchi2 [ rchi2_den == 0] = 0. fracmasked[fracmasked_den == 0] = 0. # fracin_{num,den} are in flux * nimages units tinyflux = 1e-9 fracin = fracin_num / np.maximum(tinyflux, fracin_den) return dict(fracin=fracin, fracflux=fracflux, rchisq=rchi2, fracmasked=fracmasked) def _initialize_models(src, rex): if isinstance(src, PointSource): ptsrc = src.copy() if rex: from legacypipe.survey import LogRadius simple = RexGalaxy(src.getPosition(), src.getBrightness(), LogRadius(-1.)).copy() #print('Created Rex:', simple) else: simple = SimpleGalaxy(src.getPosition(), src.getBrightness()).copy() # logr, ee1, ee2 shape = LegacyEllipseWithPriors(-1., 0., 0.) dev = DevGalaxy(src.getPosition(), src.getBrightness(), shape).copy() exp = ExpGalaxy(src.getPosition(), src.getBrightness(), shape).copy() comp = None oldmodel = 'ptsrc' elif isinstance(src, DevGalaxy): ptsrc = PointSource(src.getPosition(), src.getBrightness()).copy() simple = SimpleGalaxy(src.getPosition(), src.getBrightness()).copy() dev = src.copy() exp = ExpGalaxy(src.getPosition(), src.getBrightness(), src.getShape()).copy() comp = None oldmodel = 'dev' elif isinstance(src, ExpGalaxy): ptsrc = PointSource(src.getPosition(), src.getBrightness()).copy() simple = SimpleGalaxy(src.getPosition(), src.getBrightness()).copy() dev = DevGalaxy(src.getPosition(), src.getBrightness(), src.getShape()).copy() exp = src.copy() comp = None oldmodel = 'exp' elif isinstance(src, FixedCompositeGalaxy): ptsrc = PointSource(src.getPosition(), src.getBrightness()).copy() simple = SimpleGalaxy(src.getPosition(), src.getBrightness()).copy() frac = src.fracDev.clipped() if frac > 0: shape = src.shapeDev else: shape = src.shapeExp dev = DevGalaxy(src.getPosition(), src.getBrightness(), shape).copy() if frac < 1: shape = src.shapeExp else: shape = src.shapeDev exp = ExpGalaxy(src.getPosition(), src.getBrightness(), shape).copy() comp = src.copy() oldmodel = 'comp' return oldmodel, ptsrc, simple, dev, exp, comp def _get_subimages(tims, mods, src): subtims = [] modelMasks = [] #print('Big blob: trimming:') for tim,mod in zip(tims, mods): if mod is None: continue mh,mw = mod.shape if mh == 0 or mw == 0: continue # for modelMasks d = { src: ModelMask(0, 0, mw, mh) } modelMasks.append(d) x0,y0 = mod.x0 , mod.y0 x1,y1 = x0 + mw, y0 + mh subtim = _get_subtim(tim, x0, x1, y0, y1) if subtim.shape != (mh,mw): print('Subtim was not the shape expected:', subtim.shape, 'image shape', tim.getImage().shape, 'slice y', y0,y1, 'x', x0,x1, 'mod shape', mh,mw) subtims.append(subtim) return subtims, modelMasks def _get_subtim(tim, x0, x1, y0, y1): slc = slice(y0,y1), slice(x0, x1) subimg = tim.getImage()[slc] subpsf = tim.psf.constantPsfAt((x0+x1)/2., (y0+y1)/2.) subtim = Image(data=subimg, inverr=tim.getInvError()[slc], wcs=tim.wcs.shifted(x0, y0), psf=subpsf, photocal=tim.getPhotoCal(), sky=tim.sky.shifted(x0, y0), name=tim.name) sh,sw = subtim.shape subtim.subwcs = tim.subwcs.get_subimage(x0, y0, sw, sh) subtim.band = tim.band subtim.sig1 = tim.sig1 subtim.modelMinval = tim.modelMinval subtim.x0 = x0 subtim.y0 = y0 subtim.meta = tim.meta subtim.psf_sigma = tim.psf_sigma if tim.dq is not None: subtim.dq = tim.dq[slc] else: subtim.dq = None return subtim
[docs]class SourceModels(object): ''' This class maintains a list of the model patches for a set of sources in a set of images. ''' def __init__(self): self.filledModelMasks = True def save_images(self, tims): self.orig_images = [tim.getImage() for tim in tims] for tim,img in zip(tims, self.orig_images): tim.data = img.copy() def restore_images(self, tims): for tim,img in zip(tims, self.orig_images): tim.data = img
[docs] def create(self, tims, srcs, subtract=False): ''' Note that this modifies the *tims* if subtract=True. ''' self.models = [] for tim in tims: mods = [] sh = tim.shape ie = tim.getInvError() for src in srcs: mod = src.getModelPatch(tim) if mod is not None and mod.patch is not None: if not np.all(np.isfinite(mod.patch)): print('Non-finite mod patch') print('source:', src) print('tim:', tim) print('PSF:', tim.getPsf()) assert(np.all(np.isfinite(mod.patch))) mod = _clip_model_to_blob(mod, sh, ie) if subtract and mod is not None: mod.addTo(tim.getImage(), scale=-1) mods.append(mod) self.models.append(mods)
[docs] def add(self, i, tims): ''' Adds the models for source *i* back into the tims. ''' for tim,mods in zip(tims, self.models): mod = mods[i] if mod is not None: mod.addTo(tim.getImage())
def update_and_subtract(self, i, src, tims): for tim,mods in zip(tims, self.models): #mod = srctractor.getModelPatch(tim, src) if src is None: mod = None else: mod = src.getModelPatch(tim) if mod is not None: mod.addTo(tim.getImage(), scale=-1) mods[i] = mod def model_masks(self, i, src): modelMasks = [] for mods in self.models: d = dict() modelMasks.append(d) mod = mods[i] if mod is not None: if self.filledModelMasks: mh,mw = mod.shape d[src] = ModelMask(mod.x0, mod.y0, mw, mh) else: d[src] = ModelMask(mod.x0, mod.y0, mod.patch != 0) return modelMasks
def remap_modelmask(modelMasks, oldsrc, newsrc): mm = [] for mim in modelMasks: d = dict() mm.append(d) try: d[newsrc] = mim[oldsrc] except KeyError: pass return mm def _clip_model_to_blob(mod, sh, ie): ''' mod: Patch sh: tim shape ie: tim invError Returns: new Patch ''' mslc,islc = mod.getSlices(sh) sy,sx = mslc patch = mod.patch[mslc] * (ie[islc]>0) if patch.shape == (0,0): return None mod = Patch(mod.x0 + sx.start, mod.y0 + sy.start, patch) # Check mh,mw = mod.shape assert(mod.x0 >= 0) assert(mod.y0 >= 0) ph,pw = sh assert(mod.x0 + mw <= pw) assert(mod.y0 + mh <= ph) return mod def _select_model(chisqs, nparams, galaxy_margin, rex): ''' Returns keepmod ''' keepmod = 'none' # This is our "detection threshold": 5-sigma in # *parameter-penalized* units; ie, ~5.2-sigma for point sources cut = 5.**2 # Take the best of all models computed diff = max([chisqs[name] - nparams[name] for name in chisqs.keys() if name != 'none'] + [-1]) if diff < cut: return keepmod # We're going to keep this source! if rex: simname = 'rex' else: simname = 'simple' if not simname in chisqs: # bright stars / reference stars: we don't test the simple model. return 'ptsrc' # Now choose between point source and simple model (SIMP/REX) if chisqs.get('ptsrc',0)-nparams['ptsrc'] > chisqs.get(simname,0)-nparams[simname]: #print('Keeping source; PTSRC is better than SIMPLE') keepmod = 'ptsrc' else: #print('Keeping source; SIMPLE is better than PTSRC') #print('REX is better fit. Radius', simplemod.shape.re) keepmod = simname # For REX, we also demand a fractionally better fit if simname == 'rex': dchisq_psf = chisqs.get('ptsrc',0) dchisq_rex = chisqs.get('rex',0) if dchisq_psf > 0 and (dchisq_rex - dchisq_psf) < (0.01 * dchisq_psf): keepmod = 'ptsrc' if not ('exp' in chisqs or 'dev' in chisqs): return keepmod # This is our "upgrade" threshold: how much better a galaxy # fit has to be versus ptsrc, and comp versus galaxy. cut = galaxy_margin # This is the "fractional" upgrade threshold for ptsrc/simple->dev/exp: # 1% of ptsrc vs nothing fcut = 0.01 * chisqs.get('ptsrc', 0) #print('Cut: max of', cut, 'and', fcut, ' (fraction of chisq_psf=%.1f)' # % chisqs['ptsrc']) cut = max(cut, fcut) expdiff = chisqs.get('exp', 0) - chisqs[keepmod] devdiff = chisqs.get('dev', 0) - chisqs[keepmod] #print('EXP vs', keepmod, ':', expdiff) #print('DEV vs', keepmod, ':', devdiff) if not (expdiff > cut or devdiff > cut): #print('Keeping', keepmod) return keepmod if expdiff > devdiff: #print('Upgrading from PTSRC to EXP: diff', expdiff) keepmod = 'exp' else: #print('Upgrading from PTSRC to DEV: diff', expdiff) keepmod = 'dev' if not 'comp' in chisqs: return keepmod diff = chisqs['comp'] - chisqs[keepmod] #print('Comparing', keepmod, 'to comp. cut:', cut, 'comp:', diff) if diff < cut: return keepmod #print('Upgrading from dev/exp to composite.') keepmod = 'comp' return keepmod def _chisq_improvement(src, chisqs, chisqs_none): ''' chisqs, chisqs_none: dict of band->chisq ''' bright = src.getBrightness() bands = chisqs.keys() fluxes = dict([(b, bright.getFlux(b)) for b in bands]) dchisq = 0. for b in bands: flux = fluxes[b] if flux == 0: continue # this will be positive for an improved model d = chisqs_none[b] - chisqs[b] if flux > 0: dchisq += d else: dchisq -= np.abs(d) return dchisq def _per_band_chisqs(tractor, bands): chisqs = dict([(b,0) for b in bands]) for i,img in enumerate(tractor.images): chi = tractor.getChiImage(img=img) chisqs[img.band] = chisqs[img.band] + (chi ** 2).sum() return chisqs def _limit_galaxy_stamp_size(src, tim, maxhalf=128): from tractor.galaxy import ProfileGalaxy if isinstance(src, ProfileGalaxy): px,py = tim.wcs.positionToPixel(src.getPosition()) h = src._getUnitFluxPatchSize(tim, px, py, tim.modelMinval) if h > maxhalf: #print('halfsize', h, 'for', src, '-> setting to', maxhalf) src.halfsize = maxhalf