
#cython: boundscheck=False
#cython: wraparound=False

from   starbase import *
from   numpy    import *
# Print arrays in full (never summarise).  numpy rejects threshold=nan with a
# ValueError; sys.maxsize is the documented "no summarisation" value.
import sys
set_printoptions(threshold=sys.maxsize)
from   scipy.integrate import dblquad

# x is kernel wavelength + line wavelength
# y is kernel value * line flux
# u is lower boundary of wavelength bins into which we distribute flux
#print "BEGIN ----------------------------------------------------"

def xsmooth(spec, smc) :
   """Boxcar-smooth a 1-D spectrum over a window of ``smc`` pixels.

   For odd ``smc`` each output pixel is the mean of the ``smc`` input
   pixels centred on it.  For even ``smc`` the window is the ``smc - 1``
   centred pixels at full weight plus the two flanking pixels at half
   weight, so the total weight is again ``smc``.  Output pixels too close
   to either end for this window to fit are left at 0.

   Parameters
   ----------
   spec : numpy.ndarray
       Input spectrum (slices must support ``.sum()``).
   smc : int
       Smoothing width in pixels; must be >= 1.

   Returns
   -------
   numpy.ndarray
       float64 array of the same length as ``spec``.

   Raises
   ------
   ValueError
       If ``smc`` is smaller than 1.
   """
   if smc < 1:
      raise ValueError("smoothing width must be >= 1, got %r" % (smc,))

   al = len(spec)
   even = (smc % 2 == 0)
   # half = number of full-weight pixels on each side of the centre.
   if even:
      half = int(smc // 2) - 1
   else:
      half = int(smc // 2)

   # 2*half+1 full-weight pixels (smc for odd, smc-1 for even) plus, for
   # even widths, two half-weight pixels: the total weight is always smc.
   # (Dividing by smc-1 for odd widths over-estimated the mean.)
   norm = double(smc)

   smspec = zeros(al, double)
   for i in range(half+1, al-half-1):
      tot = spec[i-half:i+half+1].sum()
      if even:
         tot = tot + (spec[i-half-1] + spec[i+half+1])/2.
      smspec[i] = tot/norm

   return smspec

def xsmooth_mask(spec_mask, smc) :
   """Boxcar-smooth a spectrum in which zero-valued pixels are masked.

   Each output pixel is the mean of the non-zero input pixels inside the
   smoothing window.  For odd ``smc`` the window is the ``smc`` pixels
   centred on the output pixel.  For even ``smc`` it is the ``smc - 1``
   centred pixels plus the mean of the two flanking pixels, which counts as
   one extra sample and is used only when both flanking pixels are
   non-zero.  A value is written only when more than ``half`` samples
   contributed (``half`` = smc/2 rounded down, minus 1 for even ``smc``).
   Otherwise, and for pixels too close to either end, the output stays 0.

   Parameters
   ----------
   spec_mask : numpy.ndarray
       Spectrum with masked pixels set to 0.
   smc : int
       Smoothing width in pixels.

   Returns
   -------
   numpy.ndarray
       float64 array of the same length as ``spec_mask``.
   """
   npts = len(spec_mask)
   is_even = smc % 2 == 0
   half = int(smc/2-1) if is_even else int(smc/2)

   smoothed = zeros(npts, double)

   for centre in range(half+1, npts-half-1):
      lo = centre - half      # first full-weight pixel of the window
      hi = centre + half      # last full-weight pixel of the window
      nused = 0
      acc = 0

      # Accumulate only the unmasked (non-zero) pixels of the core window.
      for j in range(lo, hi+1):
         val = spec_mask[j]
         if val != 0:
            nused += 1
            acc += val

      # Even widths: the flanking pair adds its mean as one more sample,
      # but only when neither flanking pixel is masked.
      if is_even and spec_mask[lo-1] != 0 and spec_mask[hi+1] != 0.:
         acc += (spec_mask[lo-1] + spec_mask[hi+1])/2.
         nused += 1

      # Too few valid samples: leave the output pixel at 0 (masked).
      if nused > half:
         smoothed[centre] = acc/double(nused)

   return smoothed

def Calc(transmiss, instmode, objspec, skybkgd,
         redshift,
         bandnorm, bandfilter, srcextent,
         aptype, apdim1, apdim2,
         seeing, moonage, airmass, exptime, smscale,
         transcut, rand):
   """Simulate a spectroscopic exposure: counts, noise, S/N and fluxes.

   The source template is redshifted and normalised to ``bandnorm`` (AB mag
   through ``bandfilter``). It is then attenuated by atmospheric extinction
   and the A/B absorption bands, and scaled by the aperture correction,
   effective area and exposure time. Finally it is convolved with a boxcar
   one slit width wide. Sky-line, moon and dark-current counts form the
   background. Source, background and read noise give the per-pixel noise.

   Parameters
   ----------
   transmiss, instmode, objspec, skybkgd, bandfilter : str
       Starbase table files: atmospheric transmission, instrument mode
       (effective area and detector parameters), source template, sky
       background lines, and filter throughput.
   redshift : float
       Observed redshift; the template is shifted from its own ``z``.
   bandnorm : float
       Requested AB magnitude of the source through ``bandfilter``.
   srcextent, seeing : float
       Source size and seeing, combined in quadrature into a circular
       Gaussian profile (presumably arcsec, like ``apdim1`` -- confirm).
   aptype : str
       Aperture shape: "Round", "Square" or "Rectangular".
   apdim1, apdim2 : float
       Aperture dimensions in arcsec. ``apdim1`` is also the slit width
       that sets the resolution kernel. ``apdim2`` is only used for
       "Rectangular".
   moonage : int
       Selects the moon column (0, 3, 7, 10 or 14). Any other value prints
       a warning and is treated as dark.
   airmass : float
       Airmass for the extinction curve; also scales the sky-line term.
   exptime : float
       Exposure time (seconds assumed -- confirm against table units).
   smscale : int
       Boxcar smoothing width in pixels; 1 means no smoothing.
   transcut : float
       Output fluxes are masked where the convolved atmospheric
       transmission is <= ``transcut``.
   rand : str
       "yes" draws Gaussian noise; any other value adds the 1-sigma noise
       itself.

   Returns
   -------
   tuple
       ``(minx, maxx, p1miny, p1maxy, p2miny, p2maxy, p3miny, p3maxy,
       ABmag, wave, sn, cphotfin, rspecn, smspec_mask, fluxout,
       fluxout_mask, pixscale, angsperpix)``:
       - plot limits for the S/N, counts and flux panels;
       - the effective-area-weighted AB magnitude of the template before
         normalisation;
       - the pixel-centre wavelengths (microns);
       - the S/N per resolution element;
       - the counts per pixel and the noise term added to them;
       - the smoothed counts with sky lines masked;
       - the flux-calibrated spectrum (erg cm-2 s-1 A-1) without and with
         sky-line masking;
       - the pixel scale and dispersion.

   Raises
   ------
   ValueError
       If ``aptype`` is not a supported aperture shape.
   """
   seterr(under="ignore")

   # The source profile is a circular 2-D Gaussian whose variance is the
   # seeing (FWHM -> sigma via 2.3548) plus the source extent (-> sigma via
   # 1.1774 = sqrt(2 ln 2)) in quadrature.  Computed once, not per call.
   twosigsq = 2.*((seeing/2.3548)**2 + (srcextent/1.1774)**2)

   def intfunc(x, y):
      # Normalised 2-D Gaussian: exp(-r^2/(2 sigma^2)) / (pi * 2 sigma^2).
      return (0.318309886/twosigsq)*exp(-(x**2+y**2)/twosigsq)

   # Inner (y) integration limits for dblquad; the outer (x) limits are
   # +-apdim1/2.  Only called for "Square" and "Rectangular" apertures.
   def llfunc(x):
      if aptype == "Square":
         return -(apdim1)/2.
      if aptype == "Rectangular":
         return -(apdim2)/2.

   def ulfunc(x):
      if aptype == "Square":
         return (apdim1)/2.
      if aptype == "Rectangular":
         return (apdim2)/2.

   hc=1.986448e-12          # erg micron
   h=6.626076e-27           # erg s
   c=2.997925e10            # cm s-1
   k=1.380658e-16           # erg K-1
   arcpster=2.350443e-11    # steradians per square arcsecond

   # Source template, and instrument mode (effective area in m(2) vs
   # wavelength in microns, plus detector/dispersion parameters).
   d = Starbase(objspec)
   e = Starbase(instmode)

   angsperpix=e.Angsperpix[0]   # Angstrom per pixel
   pixscale=e.Pixscale[0]       # Arcsec per pixel
   # NOTE(review): unlike the values above these are not [0]-indexed.
   wavelow=e.Wavelow            # Lower band wavelength
   wavehigh=e.Wavehigh          # Upper band wavelength

   # Atmospheric transmission curve.
   trans = Starbase(transmiss)
   waveatm=trans.Wave
   atmtrans=trans.Trans

   # Output grid: pixel lower/upper edges and centres, in microns.
   micperpix=angsperpix*1.e-4
   wavelm=arange(wavelow,wavehigh,micperpix, double)
   wavehm=wavelm+angsperpix*1.e-4
   wave=(wavelm+wavehm)/2.
   readnoise=double(e.Readnoise[0])
   darkcur=double(e.Darkcur[0])

   # Slit width in pixels; this sets the boxcar resolution kernel.
   npix=apdim1/pixscale

   # Template wavelengths converted to microns (Wavecon = microns per unit).
   waveraw = d.Wave
   fluxraw = d.Flux
   wavein=waveraw*d.Wavecon
   # Scale converts template F-lambda to erg m(-2) s(-1) micron(-1); a
   # non-positive Scale marks F-nu fluxes.
   # NOTE(review): the F-nu branch (Jy / (6.626e-8 * lambda)) already yields
   # a photon flux, yet it is divided by photeng below like an energy flux.
   # The magnitude normalisation fixes the overall scale, but the spectral
   # slope may be tilted by lambda -- confirm.
   fscale=d.Scale
   if fscale > 0.0:
      flux=fluxraw*fscale
   if fscale <=0.0:
      fscale=-1.0*fscale
      flux=(fscale/6.626e-8)*(fluxraw/wavein)

   # Shift the template from its own redshift to the requested one.
   zsrc=d.z
   zobs=redshift
   wavesrc=((1.+zobs)/(1+zsrc))*wavein

   # Energy per photon at the output wavelengths (erg).
   photeng=hc/wave
   # Atmospheric transmission on the output grid.
   atmint=interp(wave, waveatm, atmtrans, 0., 0.)

   # Extinction (mag per airmass) -> transmission at this airmass.
   extspec = "Bkgd/extinct.mag.dat"
   ae = Starbase(extspec)
   waveatm2 = ae.Wave*ae.Wavecon
   extcoeff = ae.extinction
   atmtrans2 = 10**(-.4*airmass*extcoeff)
   atmint2=interp(wave, waveatm2, atmtrans2, 0., 0.)

   # A & B absorption bands.
   BBand = Starbase('Bkgd/mcmath_comp_skyabs.tab')
   wavebband = BBand.Wave*BBand.Wavecon
   ThruBBand = BBand.Thru*BBand.Thrucon
   ThruBBandint = interp(wave, wavebband, ThruBBand, 0., 0.)

   # Total atmospheric transmission applied to the source.
   Sumatmint = atmint2*ThruBBandint

   # Source photon flux on the output grid, ph m(-2) s(-1) micron(-1),
   # outside the atmosphere.
   fluxint=interp(wave, wavesrc, flux, 0., 0.)
   photint=fluxint/photeng

   # Effective area (m(2)) on the output grid.
   waveea=e.Wave
   effarea=e.EffArea
   effareaint=interp(wave, waveea, effarea, 0., 0.)

   # Filter throughput, used to measure the template magnitude in-band.
   filt_tab = Starbase(bandfilter)
   wavefilt = filt_tab.Wave*filt_tab.Wavecon
   thrufilt = filt_tab.Thruput*filt_tab.Thrucon
   thrufiltint = interp(wave, wavefilt, thrufilt, 0., 0.)
   # NOTE(review): thrufiltint enters both photintfilt and effnormfilt, so
   # the filter throughput is effectively applied squared -- confirm intent.
   photintfilt = photint*thrufiltint
   effareafiltint = effareaint*thrufiltint
   sumareafilt = sum(effareafiltint)
   effnormfilt = effareafiltint/sumareafilt
   pintfilt = sum(effnormfilt*photintfilt*wave)
   # photon flux * lambda * 6.626e-8 is F-nu in Jy; +8.9 converts to AB.
   ABmagfilt = -2.5*log10(pintfilt*6.626e-8) + 8.9

   # Unfiltered, effective-area-weighted template magnitude (returned only).
   sumarea=sum(effareaint)
   effnorm=effareaint/sumarea
   pint=sum(effnorm*photint*wave)
   if ( pint != 0.0 ) :
      ABmag=-2.5*log10(pint*6.626e-8)+8.9
   else :
      ABmag = 0

   # Scale the template so its filter magnitude equals bandnorm.
   magdiff=ABmagfilt-bandnorm
   magfac=10.**(0.4*magdiff)
   photpix=photint*micperpix*magfac

   # Fraction of the Gaussian source profile inside the aperture.
   if aptype == "Round":
      sig=((seeing/2.3548)**2 + (srcextent/1.1774)**2)**0.5
      rad=apdim1/2.
      apcor=1.-exp(-0.5*((rad/sig)**2))
   elif aptype == "Square" or aptype == "Rectangular":
      apt=dblquad(intfunc,-apdim1/2.,apdim1/2.,llfunc,ulfunc)
      apcor=apt[0]
   else:
      raise ValueError("aptype must be 'Round', 'Square' or 'Rectangular', "
                       "got %r" % (aptype,))

   # Source photons per pixel at the detector.
   tphotfin=photpix*effareaint*exptime*apcor*Sumatmint

   # Resolution kernel: normalised boxcar one slit width (npix pixels) wide.
   # numpy needs an integer length; truncate, and keep at least one pixel so
   # the kernel is never empty.
   nkern = int(npix)
   if nkern < 1:
      nkern = 1
   temp=ones(nkern)
   kernel=temp/sum(temp)
   photfin=convolve(tphotfin, kernel,"same")

   # Aperture area in square arcsec.
   if aptype == "Square":
      aparea=apdim1**2
   elif aptype == "Round":
      aparea=0.7854*apdim1**2
   else:
      aparea=apdim1*apdim2

   # Sky background table: line wavelengths/fluxes plus header constants.
   sky = Starbase(skybkgd, types = {"Wave":double, "Flux":double})
   wavebcon=double(sky.Wavecon[0])         # to microns
   fluxbcon=double(sky.Fluxcon[0])         # to ph m(-2) s(-1) arcsec(-1)
   eps=double(sky.BBodyEmis[0])
   T=double(sky.BBodyTemp[0])
   fluxintline=double(sky.Fluxintline[0])
   waveb = array(sky[:].Wave)*wavebcon
   fluxb = array(sky[:].Flux)*fluxbcon

   # Sky-line background on the output grid.  The lines are interpolated,
   # not convolved with an instrument profile.  The interpolation does not
   # depend on which line is in band, so it is done once, as soon as any
   # line falls in the band (it was previously repeated for every line).
   bk=zeros(len(wavelm))
   for wv in waveb:
      if wv >= wavelow and wv <= wavehigh:
         bk = interp(wave, waveb, fluxb, 0, 0)
         break

   # Continuum background: blackbody emission
   #   eps*2c/(lambda^4*(exp(hc/(lambda kT))-1))  ph s-1 m-2 ster-1 um-1
   # plus the interline continuum.  NOTE: tb1 is currently not included in
   # photbfin below (optical mode); kept for reference.
   term=expm1((h*c)/(wave*1.e-4*k*T))
   therm1=(eps*2.*c)/((wave*1.e-4)**4*term)
   bbody=therm1*arcpster
   contbkgd=bbody+fluxintline*atmint
   tb1=contbkgd*effareaint*micperpix*exptime*aparea

   # Sky-line (dark optical sky) counts.
   tb2=bk*effareaint*exptime*aparea*atmint*airmass

   # Moon contribution: the selected column minus the dark (moon0) column.
   m = Starbase("Bkgd/moon.photons.tab")
   moonwave = m.Wave*m.Wavecon
   if (moonage == 0):
      moon = m.moon0 - m.moon0
   elif (moonage == 3):
      moon = m.moon3 - m.moon0
   elif (moonage == 7):
      moon = m.moon7 - m.moon0
   elif (moonage == 10):
      moon = m.moon10 - m.moon0
   elif (moonage == 14):
      moon = m.moon14 - m.moon0
   else:
      print("moonage must be 0,3,7,10,14.  Assume DARK")
      moon = m.moon0 - m.moon0

   moonint = interp(wave, moonwave, moon, 0., 0.)
   moonphotons = moonint*exptime*aparea*atmint*effareaint

   # Dark current over the pixels covered by the aperture.
   tb3=darkcur*exptime*(aparea/(pixscale**2))

   photbfin=tb2+tb3+moonphotons

   # Noise: source + background shot noise plus read noise, per pixel.
   # timerat is the total/subexposure time ratio, currently always 1.
   timerat=exptime/exptime
   specvar=(photfin + photbfin + readnoise**2)
   specn=(specvar)**0.5
   cspecn=specn*(timerat**(0.5))
   cphotfin=photfin*timerat

   # NOTE(review): without randomisation the full 1-sigma noise is added as
   # a positive offset rather than omitted -- confirm this is intended.
   if ( rand == "yes" ) :
      rspecn=random.normal(0., cspecn)
   else :
      rspecn=cspecn

   specout=cphotfin+rspecn

   # S/N per resolution element (npix pixels per resolution element).
   respix=npix
   sn=((respix)**0.5)*cphotfin/cspecn

   # Mask pixels dominated by sky lines, then optionally smooth.
   specout_mask=ma.masked_where(bk > 0.005, specout)

   if smscale == 1 :
      smspec=specout
      smspec_mask=specout_mask
   else :
      smspec      = xsmooth(     specout     , smscale)
      smspec_temp = xsmooth_mask(specout_mask.filled(0), smscale)
      smspec_mask=ma.masked_equal(smspec_temp, 0.)

   # Convert counts/pixel back to F-lambda: multiply by photeng, divide by
   # effective area, micperpix, exposure time, apcor and transmission
   # (erg m(-2) s(-1) micron(-1)); the extra 1e8 gives erg cm(-2) s(-1) A(-1).
   # NOTE(review): the source was attenuated by Sumatmint (extinction x A/B
   # bands) but this correction uses atmint from ``transmiss`` -- confirm.
   atmfix=convolve(atmint, kernel, "same")
   atmfix[where(atmfix == 0.0)]=1.e-10

   fluxout=smspec*photeng/(micperpix*effareaint*exptime*apcor*atmfix*1.e8)
   fluxout=ma.masked_where(atmfix <=transcut, fluxout)
   fluxout_mask=smspec_mask*photeng/(micperpix*effareaint*exptime*apcor*atmfix*1.e8)
   fluxout_mask=ma.masked_where(atmfix <=transcut, fluxout_mask)

   # Plot limits: panel 1 = S/N, panel 2 = counts, panel 3 = flux.
   p1maxy=1.25*max(sn)
   p1miny=min(sn)

   maxx=wavehigh
   minx=wavelow

   maxphot=max(smspec)
   p2maxy=1.25*maxphot
   p2miny = 0.0

   maxflux=fluxout.max()
   p3maxy=maxflux*1.25
   p3miny = 0.0

   return (minx, maxx, p1miny, p1maxy, p2miny, p2maxy, p3miny, p3maxy
           , ABmag, wave, sn, cphotfin, rspecn, smspec_mask
           , fluxout, fluxout_mask, pixscale, angsperpix)
#, cphotfin, rspecn)
