首页Python【Python小程序】使用...

【Python小程序】使用sklearn通过KMeans聚类算法提取图像主题色

本系列文章配套代码获取有以下两种途径:

  • 通过百度网盘获取:
链接:https://pan.baidu.com/s/1i9F6oV1J5oZnIsOASDs0gQ?pwd=mnsj 提取码:mnsj
  • 前往GitHub获取
https://github.com/returu/Python_mini_program





之前文章介绍过一个用于从图像中抓取主色或有代表性调色板的Python第三方库Color Thief
【Python计算生态】Color Thief——从图像提取主色或调色板
本次将使用sklearn通过KMeans聚类算法提取图像主题色,具体操作步骤如下:
  • 使用OpenCV库读取图像并对其进行预处理;

  • 使用KMeans算法对图像中的像素进行聚类;
  • 定义RGB_histogram函数用于计算每个簇(颜色)在图像中的比例;
  • 定义plot_bar函数来绘制一个条形图,用于可视化聚类操作得到的每个簇的颜色及其在图像中所占的比例
  • 定义display_results函数使用matplotlib可视化原始图像和结果。
完整代码如下:
# 导入所需的库  
from sklearn.cluster import KMeans  # 导入KMeans聚类算法  
import matplotlib.pyplot as plt  # 导入matplotlib用于绘图  
import cv2  # 导入OpenCV用于图像处理  
import numpy as np  # 导入NumPy用于数值计算  

class ImageColorAnalyzer:  
    def __init__(self, image_path ,n_clusters=5):  
        self.image_path = image_path  # 图像路径
        self.n_clusters = n_clusters  # 初始化簇的数量    

    def RGB_histogram(self, labels):  
        """计算RGB直方图"""  
        # 计算唯一标签的数量,并创建一个包含这些标签的数组
        num_labels = np.arange(0, len(np.unique(labels)) + 1)
        # 计算直方图,即每个标签(簇)的频率 
        hist, _ = np.histogram(labels, bins=num_labels) 
        # 将直方图转换为浮点数
        hist = hist.astype("float")    
        # 归一化直方图,使所有频率的和为1 
        hist /= hist.sum()   
        return hist  

    def plot_bar(self, hist, centroids):  
        """绘制条形图"""  
        width = 600
        height = 150
        # 创建一个50x300像素的空白画布,用于绘制条形图
        bar = np.zeros((height, width, 3), dtype="uint8")  
        # 初始化条形图的起始X坐标  
        startX = 0  
        # 遍历直方图中的每个比例和对应的簇中心颜色
        for (percent, color) in zip(hist, centroids):   
            # 计算条形的结束X坐标 
            endX = startX + (percent * width)   
            # 绘制条形图
            cv2.rectangle(bar, (int(startX), 0), 
                          (int(endX), height),  
                          color.astype("uint8").tolist(), -1)   

            # 在条形上添加RGB文本  
            cv2.putText(bar, f"Color:" , (int(startX), 80), 
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1) 
            cv2.putText(bar, f"R- {color.astype('int').tolist()[0]}" , 
                        (int(startX), 100), cv2.FONT_HERSHEY_SIMPLEX, 
                        0.5, (255, 0, 0), 1) 
            cv2.putText(bar, f"G- {color.astype('int').tolist()[1]}" , 
                        (int(startX), 120), cv2.FONT_HERSHEY_SIMPLEX, 
                        0.5, (255, 0, 0), 1) 
            cv2.putText(bar, f"B- {color.astype('int').tolist()[2]}" , 
                        (int(startX), 140), cv2.FONT_HERSHEY_SIMPLEX, 
                        0.5, (255, 0, 0), 1) 

            # 在条形上添加文本 
            cv2.putText(bar, f"Percent:", (int(startX), 20),  
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)  
            cv2.putText(bar, f"{percent:.2f}", (int(startX), 40),  
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
            # 更新起始X坐标为当前条形的结束坐标
            startX = endX  
        return bar

    def analyze_image(self):  
        """分析图像并返回直方图和条形图"""  
        # 读取图像文件 
        image = cv2.imread(self.image_path)    
        if image is None:  
            raise ValueError("无法读取图像文件,请检查路径是否正确。")

        # 将图像从BGR颜色空间转换为RGB颜色空间
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)   
        # 将图像重塑为二维数组,每行代表一个像素的RGB值  
        img = image.reshape((image.shape[0] * image.shape[1], 3))  

        # 创建KMeans聚类器 
        clt = KMeans(n_clusters=self.n_clusters)  
        # 对图像像素进行聚类  
        clt.fit(img)  
        # 计算聚类结果的RGB直方图(使用clt.labels_获取训练数据所属的类别)
        hist = self.RGB_histogram(clt.labels_)  
        # 使用聚类中心和直方图绘制条形图(使用clt.cluster_centers_获取聚类中心)
        bar = self.plot_bar(hist, clt.cluster_centers_) 

        return image, bar

    def display_results(self, image, bar):  
        """显示分析结果"""  
        # 创建一个10x5的图形窗口 
        plt.figure(figsize=(20, 5))  
        # 添加一个1行2列的子图,并激活第一个子图
        plt.subplot(1, 2, 1)  
        # 显示原始图像  
        plt.imshow(image) 
        # 添加标题
        plt.title("Original Image") 
        plt.axis("off")
        # 激活第二个子图 
        plt.subplot(1, 2, 2)  
        # 显示条形图  
        plt.imshow(bar)  
        # 添加标题
        plt.title("Color Bar Chart")  
        plt.axis("off")
        # 显示图形窗口
        plt.show()  


    def process_and_display(self):  
        """处理图像并显示结果"""            
        # 分析图像  
        image , bar = self.analyze_image()  
        # 显示结果  
        self.display_results(image, bar) 

# 替换为您的图像文件路径  
image_path = "test.jpg"  
# 设置簇的数量
n_clusters = 5    
# 创建分析器实例 
analyzer = ImageColorAnalyzer(image_path , n_clusters=n_clusters)   
# 处理图像并显示结果
analyzer.process_and_display()

提出结果如下图所示:

本篇文章来源于微信公众号: 码农设计师

RELATED ARTICLES

欢迎留下您的宝贵建议

Please enter your comment!
Please enter your name here

- Advertisment -

Most Popular

Recent Comments