#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2024 HyperGas developers
#
# This file is part of hypergas.
#
# hypergas is a library to retrieve trace gases from hyperspectral satellite data
"""Reduce the random noise."""
import logging
import numpy as np
import xarray as xr
from scipy.ndimage import label
from scipy.stats.mstats import trimmed_mean
from skimage.restoration import (calibrate_denoiser, denoise_invariant,
denoise_tv_chambolle)
# Configure the root logger at import time: INFO level, timestamped records.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s: %(message)s'
_LOG_DATEFMT = '%Y/%m/%d %H:%M:%S'
logging.basicConfig(
    format=_LOG_FORMAT,
    datefmt=_LOG_DATEFMT,
    level=logging.INFO,
)

# Logger used by everything in this module.
LOG = logging.getLogger(__name__)
class Denoise():
    """Reduce random noise in a 2D field, one segment at a time.

    The field ``scene[varname]`` is split according to ``scene['segmentation']``;
    each segment is denoised independently with a total-variation (TV) filter
    and the per-segment results are summed back into one field by
    :meth:`smooth`.
    """

    def __init__(self, scene, varname, method='calibrated_tv_filter', weight=None):
        """Initialize Denoise.

        Parameters
        ----------
        scene : :class:`~satpy.Scene`
            Satpy Scene which includes the variable to be denoised and a
            ``'segmentation'`` variable.
        varname : str
            The variable to be denoised.
        method : str
            The denoising method: "tv_filter" or "calibrated_tv_filter" (default).
        weight : float, optional
            The weight for ``denoise_tv_chambolle``.
            It is ignored (and overwritten by the calibrated value) if
            ``method`` is "calibrated_tv_filter".
            If the weight is ``None`` (default) and ``method`` is "tv_filter",
            ``denoise_tv_chambolle`` uses its own default value (0.1), which is
            too low for a noisy hyperspectral gas field.
        """
        self.data = scene[varname]
        self.segmentation = scene['segmentation']
        self.weight = weight
        self.method = method

    @staticmethod
    def _small_cluster_mask(outliers, min_cluster_size):
        """Flag outlier pixels that belong to small connected clusters.

        Parameters
        ----------
        outliers : :class:`numpy.ndarray`
            Boolean array, ``True`` where a pixel is a candidate outlier.
        min_cluster_size : int
            Clusters with fewer pixels than this are flagged.

        Returns
        -------
        small : :class:`numpy.ndarray`
            Boolean array of the same shape, ``True`` only for outlier pixels
            in clusters smaller than ``min_cluster_size``.
        """
        labeled, _ = label(outliers)
        cluster_sizes = np.bincount(labeled.ravel())
        small_labels = np.flatnonzero(cluster_sizes < min_cluster_size)
        # Label 0 is the background (non-outlier pixels), not a cluster. Without
        # this exclusion a field with few valid pixels would be masked entirely.
        small_labels = small_labels[small_labels != 0]
        return np.isin(labeled, small_labels)

    def _create_mask_from_quantiles(self, image, lower_quantile=0.01, upper_quantile=0.99, min_cluster_size=10):
        """
        Create a mask based on quantile values to exclude isolated extreme values.

        Parameters
        ----------
        image : :class:`~xarray.DataArray`
            Input 2D image.
        lower_quantile : float
            Lower quantile threshold (0-1).
        upper_quantile : float
            Upper quantile threshold (0-1).
        min_cluster_size : int
            Minimum size of connected clusters to retain (in pixels).

        Returns
        -------
        mask : :class:`~xarray.DataArray`
            Masked 2D image with isolated outliers removed (set to NaN).
        """
        # Quantile thresholds are computed with xarray.
        lower_thresh = image.quantile(lower_quantile)
        upper_thresh = image.quantile(upper_quantile)

        # Candidate outliers lie outside the [lower, upper] quantile range.
        outliers = (image < lower_thresh) | (image > upper_thresh)

        # Only the labelling step needs a plain NumPy array.
        small_clusters = self._small_cluster_mask(outliers.values, min_cluster_size)
        if not small_clusters.any():
            # Nothing to remove: hand back the input unchanged.
            return image

        keep = xr.DataArray(~small_clusters, coords=image.coords, dims=image.dims)
        return image.where(keep)

    def _segment_field(self):
        """Return the squeezed data with pixels outside the current segment set to NaN.

        Reads ``self.seg_id``, which :meth:`smooth` sets before calling the filters.
        """
        return self.data.squeeze().where(self.segmentation == self.seg_id)

    @staticmethod
    def _trimmed_mean(field):
        """Return the trimmed mean of the non-NaN pixels of a ('y', 'x') field."""
        valid = field.stack(z=('y', 'x')).dropna('z')
        # limits are relative by default in scipy: 0.1 % is cut from each tail
        return trimmed_mean(valid, (1e-3, 1e-3))

    @staticmethod
    def _fill_and_mask(field, fill_value):
        """Replace NaNs in ``field`` by ``fill_value`` and return a masked array.

        The filled pixels are flagged in the mask of the returned
        :class:`numpy.ma.MaskedArray`.
        """
        invalid = field.isnull()
        return np.ma.masked_array(np.where(invalid, fill_value, field), invalid)

    def _copy_attrs(self, res):
        """Wrap ``res`` in a DataArray carrying the coords, name and attrs of the input.

        Parameters
        ----------
        res : array-like
            Denoised field with the shape of the squeezed input data.

        Returns
        -------
        res : :class:`~xarray.DataArray`
            Field named ``<input name>_denoise`` with the input attributes.
        """
        template = self.data.squeeze()
        res = xr.DataArray(res, coords=template.coords, dims=template.dims)
        res = res.rename(self.data.name + '_denoise')

        # Use a copy so appending to the description never touches the input attrs.
        res.attrs = dict(self.data.attrs)
        description = f'denoised by the {self.method} method with weight={self.weight}'
        # NOTE(review): the note is only appended to an existing description;
        # inputs without a 'description' attribute get none — confirm intended.
        if 'description' in res.attrs:
            res.attrs['description'] = f"{res.attrs['description']} ({description})"

        return res

    def tv_filter(self):
        """Apply the TV filter to the current segment.

        Pixels outside the segment (NaN) are filled with the trimmed mean of
        the segment before filtering.

        Returns
        -------
        res : :class:`numpy.ndarray`
            2D denoised field as returned by ``denoise_tv_chambolle``.
        """
        noisy = self._segment_field()
        fill_value = self._trimmed_mean(noisy)

        # Forward the weight only when set: an explicit ``weight=None`` would
        # override the denoise_tv_chambolle default instead of selecting it.
        tv_kwargs = {} if self.weight is None else {'weight': self.weight}

        return denoise_tv_chambolle(self._fill_and_mask(noisy, fill_value), **tv_kwargs)

    def calibrated_tv_filter(self, n_weights=50, return_loss=False):
        """
        Apply TV filter with `auto calibration <https://scikit-image.org/docs/0.25.x/auto_examples/filters/plot_j_invariant_tutorial.html>`_.

        The weight is calibrated on the current segment after isolated extreme
        values are removed; the calibrated filter is then applied to the full
        segment. ``self.weight`` is overwritten with the selected weight
        rounded to one decimal.

        Parameters
        ----------
        n_weights : int
            Number of weights used for auto calibration.
        return_loss : bool
            Whether return the loss results.

        Returns
        -------
        denoised_calibrated_tv : array
            2D denoised data field using calibrated parameters, as returned by
            ``denoise_invariant`` (:meth:`smooth` wraps it in a DataArray).
        weights : :class:`numpy.ndarray`, optional
            1D array of weights tested for calibration. Returned only if ``return_loss == True``.
        losses_tv : :class:`numpy.ndarray`, optional
            1D array of total variation (TV) filter losses. Returned only if ``return_loss == True``.
        """
        noisy = self._segment_field()

        # Calibrate on a copy without the isolated highest and lowest values.
        cleaned = self._create_mask_from_quantiles(noisy)
        fill_value = self._trimmed_mean(cleaned)
        calibration_input = self._fill_and_mask(cleaned, fill_value)

        # Candidate weights span std/10 .. 3*std of the calibration field.
        noise_std = np.std(calibration_input)
        weights = np.linspace(noise_std / 10, noise_std * 3, n_weights)

        _, (parameters_tested_tv, losses_tv) = calibrate_denoiser(
            calibration_input,
            denoise_tv_chambolle,
            denoise_parameters={'weight': weights},
            extra_output=True,
        )
        LOG.debug('Minimum self-supervised loss TV: %.3f', np.min(losses_tv))

        best_parameters_tv = parameters_tested_tv[np.argmin(losses_tv)]
        LOG.debug('best_parameters_tv: %s', best_parameters_tv)

        # The rounded weight is only recorded; the filter uses the exact value.
        self.weight = np.round(best_parameters_tv['weight'], 1)

        # Filter the full segment; NaNs are filled with the calibration trimmed mean.
        denoised_calibrated_tv = denoise_invariant(
            self._fill_and_mask(noisy, fill_value),
            denoise_tv_chambolle,
            denoiser_kwargs=best_parameters_tv,
        )

        if return_loss:
            return denoised_calibrated_tv, weights, losses_tv
        return denoised_calibrated_tv

    def smooth(self):
        """Smooth data by TV filter, segment by segment.

        Returns
        -------
        res_sum : :class:`~xarray.DataArray`
            Denoised field assembled from all segments.

        Raises
        ------
        ValueError
            If ``self.method`` is neither "tv_filter" nor "calibrated_tv_filter".
        """
        # Resolve the filter first so an unsupported method fails before any work.
        if self.method == 'tv_filter':
            denoiser = self.tv_filter
        elif self.method == 'calibrated_tv_filter':
            denoiser = self.calibrated_tv_filter
        else:
            raise ValueError(f'{self.method} is not supported yet.')

        res_list = []
        for seg_id in np.unique(self.segmentation):
            LOG.info('Applying denoising to segmentation_id %s ...', seg_id)
            # The filters read the current segment from self.seg_id.
            self.seg_id = seg_id

            res = self._copy_attrs(denoiser())
            # Keep values only inside this segment; zeros elsewhere let the
            # per-segment fields be summed into one.
            res_list.append(res.where(self.segmentation == seg_id, 0))

        # NOTE: with the calibrated method, self.weight (and so the description)
        # reflects the last segment only.
        res_sum = sum(res_list)

        # The summed field is re-wrapped so it carries the input name and attrs.
        return self._copy_attrs(res_sum)