zoukankan      html  css  js  c++  java
  • 画出由8个高斯分布组成的混合分布的散点图

    import matplotlib.pyplot as plt
    import numpy as np
    
    # Mixture of `num_mixtures` equally weighted 2-D Gaussians whose means lie
    # evenly spaced on a circle of the given radius.
    num_mixtures = 8
    radius = 2.0
    # NOTE: each (std, std) pair below is turned into np.diag(...) inside
    # gmm_sample, so this value acts as a per-axis *variance*; the per-axis
    # standard deviation is sqrt(0.02) ~= 0.14.
    std = 0.02
    # num_mixtures evenly spaced angles in [0, 2*pi); the duplicated 2*pi
    # endpoint from linspace is sliced off.
    thetas = np.linspace(0, 2 * np.pi, num_mixtures + 1)[:num_mixtures]
    xs = radius * np.sin(thetas)
    ys = radius * np.cos(thetas)
    mix_coeffs = (1 / num_mixtures,) * num_mixtures
    mean = tuple(zip(xs, ys))
    cov = ((std, std),) * num_mixtures
    # Plot state: no figure/axes created yet; epoch is used in the output filename.
    fig = None
    ax = None
    epoch = 0
            
    
    def gmm_sample(num_samples, mix_coeffs, mean, cov):
        """Draw samples from a Gaussian mixture with diagonal covariances.

        Args:
            num_samples: total number of points to draw.
            mix_coeffs: one mixture weight per component; passed straight to
                ``np.random.multinomial``, so the weights should sum to 1.
            mean: sequence of per-component mean vectors, all the same length.
            cov: sequence of per-component diagonal entries (one value per
                dimension); each row ``r`` is used as the covariance
                ``np.diag(r)``, i.e. the values are variances.

        Returns:
            ndarray of shape ``(num_samples, len(mean[0]))``. Rows are grouped
            by component (all draws of component 0 first, then component 1,
            ...) and are not shuffled.
        """
        # How many of the num_samples points each component contributes.
        counts = np.random.multinomial(num_samples, mix_coeffs)
        # Convert once up front rather than on every loop iteration.
        means = np.asarray(mean)
        variances = np.asarray(cov)
        samples = np.zeros(shape=[num_samples, means.shape[1]])
        start = 0
        # Index by component number (not zip) so that mean/cov shorter than
        # mix_coeffs still fails loudly instead of leaving zero rows.
        for i, count in enumerate(counts):
            stop = start + count
            samples[start:stop, :] = np.random.multivariate_normal(
                mean=means[i, :],
                cov=np.diag(variances[i, :]),
                size=count)
            start = stop
        return samples
    
    def disp_scatter(x, fig=None, ax=None):
        """Scatter-plot 2-D points as red '+' markers labelled 'real data'.

        Args:
            x: array of points; column 0 is plotted on the horizontal axis and
                column 1 on the vertical axis.
            fig: figure owning ``ax``. If ``ax`` is given but ``fig`` is not,
                ``ax.figure`` is returned in its place.
            ax: axes to draw on; when None a new figure and axes are created
                with ``plt.subplots()`` (any ``fig`` passed is then replaced).

        Returns:
            ``(fig, ax)``: the figure and axes that were drawn on.
        """
        if ax is None:
            fig, ax = plt.subplots()
        elif fig is None:
            # Derive the owning figure so callers can always call
            # fig.tight_layout() / fig.savefig() on the result.
            fig = ax.figure
        ax.scatter(x[:, 0], x[:, 1], s=10, marker='+', color='r', alpha=0.8, label='real data')

        ax.legend()
        return fig, ax
    num_samples = 1000  # total number of points drawn from the mixture

    # Draw the points, plot them on a freshly created figure, and save it.
    x = gmm_sample(num_samples, mix_coeffs, mean, cov)
    fig, ax = disp_scatter(x)
    fig.tight_layout()
    fig.savefig("output{}.png".format(epoch))

    num_mixtures = 8

    num_mixtures = 1

  • 相关阅读:
    Redis 安装
    Git的安装和使用
    HTML5 本地存储+layer弹层组件制作记事本
    PHP 微信公众号开发
    PHP 微信公众号开发
    Electron 安装与使用
    HTML5 桌面消息提醒
    Composer安装和使用
    玄 学
    区间内的真素数
  • 原文地址:https://www.cnblogs.com/gaona666/p/12446784.html
Copyright © 2011-2022 走看看