From 341c639dce4034d59a66d15764024abbe05b9035 Mon Sep 17 00:00:00 2001 From: Gasper Spagnolo Date: Sat, 29 Oct 2022 16:57:29 +0200 Subject: [PATCH] Koncno dela --- assignment2/solution.py | 21 ++++++++++++--- assignment2/uz_framework/image.py | 45 ++++++++++++++++++++++++++++--- 2 files changed, 58 insertions(+), 8 deletions(-) diff --git a/assignment2/solution.py b/assignment2/solution.py index 731c18f..10d007d 100644 --- a/assignment2/solution.py +++ b/assignment2/solution.py @@ -10,8 +10,9 @@ import uz_framework.image as uz_image ################################################################# def ex1(): one_a() + #one_b() -def one_a(): +def one_a() -> npt.NDArray[np.float64]: """ Firstly, you will implement the function myhist3 that computes a 3-D histogram from a three channel image. The images you will use are RGB, but the function @@ -24,10 +25,22 @@ def one_a(): the resulting histogram. """ test_image = uz_image.imread('./data/images/museum.jpg', uz_image.ImageType.float64) - uz_image.get_image_bins_ND(test_image, 10) - plt.imshow(test_image) - plt.show() + bins = uz_image.get_image_bins_ND(test_image, 10) + return bins +def one_b(): + """ + In order to perform image comparison using histograms, we need to implement + some distance measures. These are defined for two input histograms and return a + single scalar value that represents the similarity (or distance) between the two histograms. + Implement a function compare_histograms that accepts two histograms + and a string that identifies the distance measure you wish to calculate + Implement L2 metric, chi-square distance, intersection and Hellinger distance. + """ + test_image = uz_image.imread('./data/images/museum.jpg', uz_image.ImageType.float64) + bins = uz_image.get_image_bins_ND(test_image, 10) + + uz_image.compare_two_histograms(bins[0], bins[1], uz_image.DistanceMeasure.chi_square_distance) # ######## # # SOLUTION # diff --git a/assignment2/uz_framework/image.py b/assignment2/uz_framework/image.py index 550676e..b9a3bdf 100644 --- a/assignment2/uz_framework/image.py +++ b/assignment2/uz_framework/image.py @@ -11,6 +11,12 @@ class ImageType(enum.Enum): uint8 = 0 float64 = 1 +class DistanceMeasure(enum.Enum): + euclidian_distance = 0 + chi_square_distance = 1 + intersection_distance = 2 + hellinger_distance = 3 + def imread(path: str, type: ImageType) -> Union[npt.NDArray[np.float64], npt.NDArray[np.uint8]]: """ Reads an image in RGB order. Image type is transformed from uint8 to float, and @@ -210,12 +216,43 @@ def get_image_bins(image: Union[npt.NDArray[np.float64], npt.NDArray[np.uint8]] for i in empty_bins: counts = np.insert(counts, i - 1, 0) - return counts / np.sum(counts) + return counts def get_image_bins_ND(image: Union[npt.NDArray[np.float64], npt.NDArray[np.uint8]], number_of_bins: int) -> npt.NDArray[np.float64]: - for dimension in range(image.shape[2]): - print(get_image_bins(image[dimension], number_of_bins)) - return None + + bs = [] + hist = np.zeros((number_of_bins, number_of_bins, number_of_bins)) + if image.dtype.type == np.uint8: + bins = np.linspace(0, 255, num=number_of_bins) + elif image.dtype.type == np.float64: + bins = np.linspace(0, 1, num=number_of_bins) + else: + raise Exception('Unsuported datatype!') + + for i in range(image.shape[2]): + v = image[:, :, i].reshape(-1) + bs.append(np.digitize(v, bins).astype(np.uint32)) + + for i in range(len(bs[0])): + hist[bs[2][i] -1, bs[1][i] -1, bs[0][i] - 1] += 1 + + return hist / np.sum(hist) + +def compare_two_histograms(h1: npt.NDArray[np.float64], h2: npt.NDArray[np.float64], method: DistanceMeasure) -> float: + + if method == DistanceMeasure.euclidian_distance: + d = np.sqrt(np.sum(np.square(h1 - h2))) + elif method == DistanceMeasure.chi_square_distance: + d = 0.5 * np.sum(np.square(h1 - h2) / (h1 + h2 + np.finfo(float).eps)) + elif method == DistanceMeasure.intersection_distance: + d = 0.0 + elif method == DistanceMeasure.hellinger_distance: + d = 0.0 + else: + raise Exception('Unsuported method chosen!') + + + return d def apply_mask_on_image(image: Union[npt.NDArray[np.float64], npt.NDArray[np.uint8]], mask: npt.NDArray[np.uint8]):