Skip to content

Commit

Permalink
Formatting and misc. cleanup to match openet-core
Browse files Browse the repository at this point in the history
  • Loading branch information
cgmorton committed Mar 12, 2024
1 parent e2498d2 commit e5d364d
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 52 deletions.
3 changes: 1 addition & 2 deletions openet/ptjpl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# from importlib import metadata

from .collection import Collection
from .image import Image
from . import interpolate

MODEL_NAME = 'PTJPL'

# from importlib import metadata
# # __version__ = metadata.version(__package__ or __name__)
# __version__ = metadata.version(__package__.replace('.', '-') or __name__.replace('.', '-'))
# # __version__ = metadata.version('openet-ptjpl')
14 changes: 3 additions & 11 deletions openet/ptjpl/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,9 +489,8 @@ def interpolate(
if type(self.model_args['et_reference_source']) is str:
# Assume a string source is a single image collection ID
# not a list of collection IDs or ee.ImageCollection
daily_et_ref_coll_id = self.model_args['et_reference_source']
daily_et_ref_coll = (
ee.ImageCollection(daily_et_ref_coll_id)
ee.ImageCollection(self.model_args['et_reference_source'])
.filterDate(start_date, end_date)
.select([self.model_args['et_reference_band']], ['et_reference'])
)
Expand Down Expand Up @@ -682,12 +681,7 @@ def aggregate_image(agg_start_date, agg_end_date, date_format):
"""
if ('et' in variables) or ('et_fraction' in variables):
et_img = (
daily_coll
.filterDate(agg_start_date, agg_end_date)
.select(['et'])
.sum()
)
et_img = daily_coll.filterDate(agg_start_date, agg_end_date).select(['et']).sum()

if ('et_reference' in variables) or ('et_fraction' in variables):
# Get the reference ET image from the reference ET collection,
Expand All @@ -711,9 +705,7 @@ def aggregate_image(agg_start_date, agg_end_date, date_format):
if 'et_reference' in variables:
image_list.append(et_reference_img.float())
if 'et_fraction' in variables:
image_list.append(
et_img.divide(et_reference_img).rename(['et_fraction']).float()
)
image_list.append(et_img.divide(et_reference_img).rename(['et_fraction']).float())
if 'ndvi' in variables:
ndvi_img = (
daily_coll.filterDate(agg_start_date, agg_end_date)
Expand Down
49 changes: 10 additions & 39 deletions openet/ptjpl/interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,16 +256,12 @@ def interpolate_prep(img):
img.select(interp_vars).double().multiply(ee.Number(img.get('scale_factor')))
.addBands([mask_img, time_img])
.set({'system:time_start': ee.Number(img.get('system:time_start'))})
# .set({'image_id': ee.String(img.get('system:index'))})
)
# .set({'image_id': ee.String(img.get('system:index'))})

# Filter scene collection to the interpolation range
# This may not be needed since scene_coll was built to this range
scene_coll = (
scene_coll
.filterDate(interp_start_date, interp_end_date)
.map(interpolate_prep)
)
scene_coll = scene_coll.filterDate(interp_start_date, interp_end_date).map(interpolate_prep)

# For count, compute the composite/mosaic image for the mask band only
if 'count' in variables:
Expand Down Expand Up @@ -343,20 +339,11 @@ def aggregate_image(agg_start_date, agg_end_date, date_format):
"""
if ('et' in variables) or ('et_fraction' in variables):
et_img = (
daily_coll
.filterDate(agg_start_date, agg_end_date)
.select(['et'])
.sum()
)
et_img = daily_coll.filterDate(agg_start_date, agg_end_date).select(['et']).sum()

if ('et_reference' in variables) or ('et_fraction' in variables):
et_reference_img = (
daily_et_ref_coll
.filterDate(agg_start_date, agg_end_date)
.select(['et_reference'])
.sum()
)
et_reference_img = daily_et_ref_coll.filterDate(agg_start_date, agg_end_date).sum()

if et_reference_resample and (et_reference_resample in ['bilinear', 'bicubic']):
et_reference_img = (
et_reference_img
Expand All @@ -371,9 +358,7 @@ def aggregate_image(agg_start_date, agg_end_date, date_format):
image_list.append(et_reference_img.float())
if 'et_fraction' in variables:
# Compute average et fraction over the aggregation period
image_list.append(
et_img.divide(et_reference_img).rename(['et_fraction']).float()
)
image_list.append(et_img.divide(et_reference_img).rename(['et_fraction']).float())
if 'ndvi' in variables:
# Compute average ndvi over the aggregation period
ndvi_img = (
Expand Down Expand Up @@ -624,15 +609,13 @@ def from_scene_et_actual(

# Assume a string source is a single image collection ID
# not a list of collection IDs or ee.ImageCollection
daily_et_ref_coll_id = model_args['et_reference_source']
daily_et_ref_coll = (
ee.ImageCollection(daily_et_ref_coll_id)
ee.ImageCollection(model_args['et_reference_source'])
.filterDate(start_date, end_date)
.select([model_args['et_reference_band']], ['et_reference'])
)

# Scale reference ET images (if necessary)
# CGM - Resampling is not working correctly so not including for now
if et_reference_factor and (et_reference_factor != 1):
def et_reference_adjust(input_img):
return (
Expand All @@ -642,7 +625,6 @@ def et_reference_adjust(input_img):
)
daily_et_ref_coll = daily_et_ref_coll.map(et_reference_adjust)


# Target collection needs to be filtered to the same date range as the
# scene collection in order to normalize the scenes.
# It will be filtered again to the start/end when it is sent into
Expand Down Expand Up @@ -766,21 +748,12 @@ def aggregate_image(agg_start_date, agg_end_date, date_format):
"""
if ('et' in variables) or ('et_fraction' in variables):
et_img = (
daily_coll
.filterDate(agg_start_date, agg_end_date)
.select(['et'])
.sum()
)
et_img = daily_coll.filterDate(agg_start_date, agg_end_date).select(['et']).sum()

if ('et_reference' in variables) or ('et_fraction' in variables):
# Get the reference ET image from the reference ET collection,
# not the interpolated collection
et_reference_img = (
daily_et_ref_coll
.filterDate(agg_start_date, agg_end_date)
.sum()
)
et_reference_img = daily_et_ref_coll.filterDate(agg_start_date, agg_end_date).sum()
if et_reference_resample and (et_reference_resample in ['bilinear', 'bicubic']):
et_reference_img = (
et_reference_img
Expand All @@ -795,9 +768,7 @@ def aggregate_image(agg_start_date, agg_end_date, date_format):
image_list.append(et_reference_img.float())
if 'et_fraction' in variables:
# Compute average et fraction over the aggregation period
image_list.append(
et_img.divide(et_reference_img).rename(['et_fraction']).float()
)
image_list.append(et_img.divide(et_reference_img).rename(['et_fraction']).float())
# if 'ndvi' in variables:
# # Compute average ndvi over the aggregation period
# ndvi_img = (
Expand Down

0 comments on commit e5d364d

Please sign in to comment.