Add enhance_image4 function for advanced image enhancement using MSRCR
This commit is contained in:
		
							
								
								
									
										70
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										70
									
								
								main.py
									
									
									
									
									
								
							@@ -308,6 +308,71 @@ def enhance_image3(image: bytes) -> bytes:
 | 
				
			|||||||
    return enhanced_webp.tobytes()
 | 
					    return enhanced_webp.tobytes()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def enhance_image4(image: bytes) -> bytes:  # noqa: PLR0914
 | 
				
			||||||
 | 
					    """Enhance an image using a simplified Multi-Scale Retinex with Color Restoration (MSRCR) algorithm to better reveal details in dark areas.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    This approach first denoises the image, then applies multi-scale retinex processing to
 | 
				
			||||||
 | 
					    boost local contrast. A color restoration step helps maintain a natural look.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        image (bytes): The input image to enhance.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					        bytes: The enhanced image encoded in WebP format.
 | 
				
			||||||
 | 
					    """  # noqa: E501
 | 
				
			||||||
 | 
					    # Decode the image from bytes
 | 
				
			||||||
 | 
					    nparr = np.frombuffer(image, np.uint8)
 | 
				
			||||||
 | 
					    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Denoise the image with conservative settings
 | 
				
			||||||
 | 
					    img = cv2.fastNlMeansDenoisingColored(img, None, 5, 5, 7, 21)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Convert to float64 and add 1 to avoid log(0)
 | 
				
			||||||
 | 
					    img = img.astype(np.float64) + 1.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Define scales for the multi-scale retinex (you can experiment with these)
 | 
				
			||||||
 | 
					    scales = [15, 80, 250]
 | 
				
			||||||
 | 
					    retinex = np.zeros_like(img)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Compute the retinex output over different scales
 | 
				
			||||||
 | 
					    for sigma in scales:
 | 
				
			||||||
 | 
					        # Gaussian blur with standard deviation sigma; kernel size is computed automatically
 | 
				
			||||||
 | 
					        blur = cv2.GaussianBlur(img, (0, 0), sigma)
 | 
				
			||||||
 | 
					        retinex += np.log(img) - np.log(blur)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Average the retinex result over all scales
 | 
				
			||||||
 | 
					    retinex /= len(scales)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # --- Color Restoration Step ---
 | 
				
			||||||
 | 
					    # Compute the sum across color channels (with a small epsilon to avoid division by zero)
 | 
				
			||||||
 | 
					    eps = 1e-6
 | 
				
			||||||
 | 
					    sum_channels = np.sum(img, axis=2, keepdims=True) + eps
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # The color restoration factor; alpha is chosen empirically (here 125 works well in many cases)
 | 
				
			||||||
 | 
					    color_restoration = np.log(125 * img / sum_channels + 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Combine the retinex output with the color restoration factor
 | 
				
			||||||
 | 
					    msrcr = retinex * color_restoration
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Apply gain and offset adjustments to fine-tune brightness and contrast
 | 
				
			||||||
 | 
					    gain = 1.5
 | 
				
			||||||
 | 
					    offset = 20
 | 
				
			||||||
 | 
					    msrcr = msrcr * gain + offset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Normalize each channel to span the full 0-255 range
 | 
				
			||||||
 | 
					    for channel in range(3):
 | 
				
			||||||
 | 
					        ch_data = msrcr[:, :, channel]
 | 
				
			||||||
 | 
					        ch_min, ch_max = ch_data.min(), ch_data.max()
 | 
				
			||||||
 | 
					        msrcr[:, :, channel] = ((ch_data - ch_min) / (ch_max - ch_min + eps)) * 255
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Clip the values to valid 8-bit range and convert back to uint8
 | 
				
			||||||
 | 
					    enhanced_img = np.clip(msrcr, 0, 255).astype(np.uint8)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Encode the enhanced image to WebP format
 | 
				
			||||||
 | 
					    _, enhanced_webp = cv2.imencode(".webp", enhanced_img)
 | 
				
			||||||
 | 
					    return enhanced_webp.tobytes()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@client.tree.context_menu(name="Enhance Image")
 | 
					@client.tree.context_menu(name="Enhance Image")
 | 
				
			||||||
@app_commands.allowed_installs(guilds=True, users=True)
 | 
					@app_commands.allowed_installs(guilds=True, users=True)
 | 
				
			||||||
@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
 | 
					@app_commands.allowed_contexts(guilds=True, dms=True, private_channels=True)
 | 
				
			||||||
@@ -350,7 +415,10 @@ async def enhance_image_command(interaction: discord.Interaction, message: disco
 | 
				
			|||||||
        enhanced_image3: bytes = enhance_image3(image_bytes)
 | 
					        enhanced_image3: bytes = enhance_image3(image_bytes)
 | 
				
			||||||
        file3 = discord.File(fp=io.BytesIO(enhanced_image3), filename=f"enhanced3-{timestamp}.webp")
 | 
					        file3 = discord.File(fp=io.BytesIO(enhanced_image3), filename=f"enhanced3-{timestamp}.webp")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        await interaction.followup.send("Enhanced version:", files=[file1, file2, file3])
 | 
					        enhanced_image4: bytes = enhance_image4(image_bytes)
 | 
				
			||||||
 | 
					        file4 = discord.File(fp=io.BytesIO(enhanced_image4), filename=f"enhanced4-{timestamp}.webp")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        await interaction.followup.send("Enhanced version:", files=[file1, file2, file3, file4])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    except (httpx.HTTPError, openai.OpenAIError) as e:
 | 
					    except (httpx.HTTPError, openai.OpenAIError) as e:
 | 
				
			||||||
        logger.exception("Failed to enhance image")
 | 
					        logger.exception("Failed to enhance image")
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user