def enhance_image4(image: bytes) -> bytes:  # noqa: PLR0914
    """Reveal detail in dark areas using a simplified MSRCR algorithm.

    MSRCR stands for Multi-Scale Retinex with Color Restoration. The image is
    first denoised, then multi-scale retinex processing boosts local contrast,
    and a color restoration step helps maintain a natural look.

    Args:
        image (bytes): The encoded input image to enhance.

    Returns:
        bytes: The enhanced image encoded in WebP format.

    Raises:
        ValueError: If the input bytes cannot be decoded as an image, or if
            the enhanced image cannot be encoded as WebP.
    """
    # Decode the image from bytes; imdecode signals failure by returning None
    nparr = np.frombuffer(image, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if img is None:
        msg = "Could not decode the input bytes as an image"
        raise ValueError(msg)

    # Denoise the image with conservative settings
    img = cv2.fastNlMeansDenoisingColored(img, None, 5, 5, 7, 21)

    # Convert to float64 and add 1 to avoid log(0)
    img = img.astype(np.float64) + 1.0

    # Define scales for the multi-scale retinex (you can experiment with these)
    scales = [15, 80, 250]

    # log(img) does not depend on the scale, so compute it once
    log_img = np.log(img)
    retinex = np.zeros_like(img)

    # Accumulate the single-scale retinex output for every scale
    for sigma in scales:
        # Gaussian blur with standard deviation sigma; a (0, 0) kernel size
        # lets OpenCV derive the kernel size from sigma
        blur = cv2.GaussianBlur(img, (0, 0), sigma)
        retinex += log_img - np.log(blur)

    # Average the retinex result over all scales
    retinex /= len(scales)

    # --- Color Restoration Step ---
    # Sum across color channels (small epsilon guards against division by zero)
    eps = 1e-6
    sum_channels = np.sum(img, axis=2, keepdims=True) + eps

    # The color restoration factor; alpha (125) is chosen empirically
    color_restoration = np.log(125 * img / sum_channels + 1)

    # Combine the retinex output with the color restoration factor
    msrcr = retinex * color_restoration

    # Apply gain and offset adjustments.
    # NOTE: the per-channel min-max normalization below cancels this affine
    # transform (apart from the tiny eps term), so these values barely affect
    # the final output. They are kept to preserve the existing result.
    gain = 1.5
    offset = 20
    msrcr = msrcr * gain + offset

    # Normalize each channel independently to span the full 0-255 range
    ch_min = msrcr.min(axis=(0, 1), keepdims=True)
    ch_max = msrcr.max(axis=(0, 1), keepdims=True)
    msrcr = ((msrcr - ch_min) / (ch_max - ch_min + eps)) * 255

    # Clip the values to valid 8-bit range and convert back to uint8
    enhanced_img = np.clip(msrcr, 0, 255).astype(np.uint8)

    # Encode the enhanced image to WebP format, checking the success flag
    ok, enhanced_webp = cv2.imencode(".webp", enhanced_img)
    if not ok:
        msg = "Could not encode the enhanced image as WebP"
        raise ValueError(msg)
    return enhanced_webp.tobytes()