zoukankan      html  css  js  c++  java
  • 常见代码

    1.检测文件/目录 是否存在

    from os.path import isfile, isdir
    if not isdir(vgg_dir): raise Exception("VGG directory doesn't exist!") vgg_dir = 'tensorflow_vgg/' if not isdir(vgg_dir):   raise Exception("VGG directory doesn't exist!")

    列出指定目录下的文件及遍历 目录名

    import os
    data_dir = 'flower_photos/'
    contents = os.listdir(data_dir)
    print(contents)
    classes = [each for each in contents if os.path.isdir(data_dir + each)]
    print(classes)
    ['daisy', 'dandelion', 'LICENSE.txt', 'roses', 'sunflowers', 'tulips']
    ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']

    2.进度条:

    https://blog.csdn.net/qq_40666028/article/details/79335961

    from tqdm import trange
    import time
    for i in trange(200):
        time.sleep(0.1)    
    from tqdm import tqdm
    from urllib.request import urlretrieve
    class DLProgress(tqdm): # 继承tqdm类 last_block = 0 def hook(self, block_num=1, block_size=1, total_size=None): self.total = total_size self.update((block_num - self.last_block) * block_size) self.last_block = block_num with DLProgress(unit='B', unit_scale=True, miniters=1, desc='VGG16 Parameters') as pbar: ''' urlretrieve(url, filename=None, reporthook=None, data=None)方法直接将远程数据下载到本地 filename指定了保存本地路径(如果参数未指定,urllib会生成一个临时文件保存数据。 reporthook是一个回调函数,当连接上服务器、以及相应的数据块传输完毕时会触发该回调,我们可以利用这个回调函数来显示当前的下载进度。 data指post导服务器的数据,该方法返回一个包含两个元素的(filename, headers) 元组,filename 表示保存到本地的路径,header表示服务器的响应头 ''' urlretrieve( 'https://s3.amazonaws.com/content.udacity-data.com/nd101/vgg16.npy', vgg_dir + 'vgg16.npy', pbar.hook)
    #!/usr/bin/env python
    # coding=utf-8
    import os
    import urllib
    
    def cbk(a,b,c):
        '''回调函数
        @a:已经下载的数据块
        @b:数据块的大小
        @c:远程文件的大小
        '''
        per=100.0*a*b/c
        if per>100:
            per=100
        print '%.2f%%' % per
    
    url='http://www.python.org/ftp/python/2.7.5/Python-2.7.5.tar.bz2'
    dir=os.path.abspath('.')
    work_path=os.path.join(dir,'Python-2.7.5.tar.bz2')
    urllib.urlretrieve(url,work_path,cbk)

    3.压缩和解压缩

    import tarfile
    
    dataset_folder_path = 'flower_photos'
    
    #先下载到当前目录 class DLProgress(tqdm): last_block = 0 def hook(self, block_num=1, block_size=1, total_size=None): self.total = total_size self.update((block_num - self.last_block) * block_size) self.last_block = block_num if not isfile('flower_photos.tar.gz'): with DLProgress(unit='B', unit_scale=True, miniters=1, desc='Flowers Dataset') as pbar: urlretrieve( 'http://download.tensorflow.org/example_images/flower_photos.tgz', 'flower_photos.tar.gz', pbar.hook)
    #下载到当前目录后解压缩到当前目录
    if not isdir(dataset_folder_path): with tarfile.open('flower_photos.tar.gz') as tar: tar.extractall() tar.close()
    if not isdir('dir_path'):
        with ZipFile('imgs.zip', 'r') as zipf:   
            for name in tqdm(zipf.namelist()[:1000],desc='Extract files', unit='files'):
                zipf.extract(name, path='dir_path')
            zipf.close()

     4. numpy数组的保存 和加载

     一般tensorflow session.run后的结果是 numpy数组,可以保存到文件目录,以后可以加载

    # write codes to file
    with open('codes', 'w') as f:
        codes.tofile(f)
        
    # write labels to file
    import csv
    with open('labels', 'w') as f:
        writer = csv.writer(f, delimiter='
    ')
        writer.writerow(labels)
    # read codes and labels from file
    import csv
    
    with open('labels') as f:
        reader = csv.reader(f, delimiter='
    ')
    # squeeze() 去除大小为1的维度 https://blog.csdn.net/lqfarmer/article/details/73323449
    labels = np.array([each for each in reader if len(each) > 0]).squeeze()
    with open('codes') as f:
        codes = np.fromfile(f, dtype=np.float32)
        # 参考https://blog.csdn.net/weixin_39449570/article/details/78619196
        # -1 表示列数 为自动计算

    codes = codes.reshape((len(labels), -1))
                                                 

     5. One hot encoding

    # sklearn的方法:  三行搞定

    from
    sklearn.preprocessing import LabelBinarizer lb = LabelBinarizer() lb.fit(labels) labels_vecs = lb.transform(labels)
    # keras的方法: 一行搞定
    from
    keras.utils import np_utils # one-hot encode the labels num_classes = len(np.unique(y_train)) y_train = keras.utils.to_categorical(y_train, num_classes) y_test = keras.utils.to_categorical(y_test, num_classes) 
    #  tensorflow的方法: 

    import
    tensorflow as tf
    CLASS= len(np.unique([0,1,2,3,4,5,6,7]))   label1
    =tf.constant([0,1,2,3,4,5,6,7]) sess1=tf.Session() print('label1:',sess1.run(label1)) b = tf.one_hot(label1,CLASS,1,0) with tf.Session() as sess: #sess.run(tf.global_variables_initializer()) sess.run(b) print('after one_hot',sess.run(b))


    # 核心4行搞定:
    label1=tf.constant([0,1,2,3,4,5,6,7])
    b = tf.one_hot(label1,8,1,0)
    with tf.Session() as sess:
    sess.run(b)
         
    结果:
    label1: [0 1 2 3 4 5 6 7] after one_hot [[1 0 0 0 0 0 0 0] [0 1 0 0 0 0 0 0] [0 0 1 0 0 0 0 0] [0 0 0 1 0 0 0 0] [0 0 0 0 1 0 0 0] [0 0 0 0 0 1 0 0] [0 0 0 0 0 0 1 0] [0 0 0 0 0 0 0 1]]

    6. split 训练集  验证集 测试集

    from sklearn.model_selection import StratifiedShuffleSplit
    # https://blog.csdn.net/m0_38061927/article/details/76180541
    ss = StratifiedShuffleSplit(n_splits=1, test_size=0.2) # 分一次,训练 测试 8:2
    train_idx, val_idx = next(ss.split(codes, labels))
    
    half_val_len = int(len(val_idx)/2)
    val_idx, test_idx = val_idx[:half_val_len], val_idx[half_val_len:]
    
    train_x, train_y = codes[train_idx], labels_vecs[train_idx]
    val_x, val_y = codes[val_idx], labels_vecs[val_idx]
    test_x, test_y = codes[test_idx], labels_vecs[test_idx]

    tf的sess 结果存到磁盘上

    saver = tf.train.Saver()
    
    with tf.Session() as sess:
        。。。  
        。。。
       saver.save(sess, "checkpoints/flowers.ckpt")

    监测是否能用gpu,是返回true

    import tensorflow as tf
    tf.test.is_gpu_available(
        cuda_only=False,
        min_cuda_compute_capability=None
    )

    keras常用代码

    from keras.layers import Conv2D
    
    #卷积层 ,16个过滤器,过滤器大小 滑动strides默认为1,
    Conv2D(filters=16, kernel_size=2, strides=2, activation='relu', input_shape=(200, 200, 1))

    字符处理

    from string import punctuation # 标点符号!"#$%&'()*+,-./:;<=>?@[]^_`{|}~
    
    #遍历每一个字符,去除停用词(标点符号),再连接起来
    doc = ''.join([c for c in reviews if c not in punctuation]) 
    
    #去除换行符(用换行符分割,再用空格连起来)
    reviews = doc.split('
    ')  
    all_text = ' '.join(reviews)  
    
    #分词。默认分隔符为空格。
    words = all_text.split()
  • 相关阅读:
    《大话设计模式》的一些总结
    一个仿jdkd的动态代理
    一道笔试题(构造数组)
    c# 汉字转拼音
    IDEA常用插件盘点(香~~)
    服务器概念、应用服务器盘点大科普
    创建一个简单的Struts 2程序
    JAVA(Object类、Date类、Dateformat类、Calendar类)
    DQL查询语句和约束
    MySQL操作语句
  • 原文地址:https://www.cnblogs.com/lhuser/p/9218237.html
Copyright © 2011-2022 走看看