博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
识花模型代码理解
阅读量:2457 次
发布时间:2019-05-10

本文共 4880 字,大约阅读时间需要 16 分钟。

import osimport numpy as npimport tensorflow as tffrom tensorflow_vgg import vgg16from tensorflow_vgg import utils
# Root directory of the flower photo dataset; each sub-directory holds one class.
data_dir = 'flower_photos/'
contents = os.listdir(data_dir)
# Keep only the entries that are directories -- those are the flower classes.
classes = [name for name in contents if os.path.isdir(data_dir + name)]
# Number of images pushed through VGG16 at once; raise it if RAM allows.
batch_size = 10
# codes_list: kept for parity with the original script (unused afterwards).
codes_list = []
# labels: one class-name string per processed image.
labels = []
# batch: temporary staging area for image arrays until a full batch is ready.
batch = []
# codes: accumulated feature matrix, one row per image (filled lazily).
codes = None

with tf.Session() as sess:
    # Build the VGG16 graph once; features are read from the relu6 layer.
    vgg = vgg16.Vgg16()
    input_ = tf.placeholder(tf.float32, [None, 224, 224, 3])  # None = variable batch
    with tf.name_scope("content_vgg"):  # name scope keeps the graph tidy
        vgg.build(input_)

    # Run every class's images through VGG16 and collect the feature vectors.
    for flower_class in classes:
        print("Starting {} images".format(flower_class))
        class_path = data_dir + flower_class
        files = os.listdir(class_path)
        for idx, filename in enumerate(files, 1):
            # Load the image and stage it for the next batched forward pass.
            img = utils.load_image(os.path.join(class_path, filename))
            batch.append(img.reshape((1, 224, 224, 3)))
            labels.append(flower_class)

            # Flush when the batch is full, or on the class's final image.
            if idx % batch_size == 0 or idx == len(files):
                stacked = np.concatenate(batch)
                # feed_dict supplies concrete values for the placeholder tensor.
                features = sess.run(vgg.relu6, feed_dict={input_: stacked})

                # Append this batch's features to the running matrix.
                codes = features if codes is None else np.concatenate((codes, features))

                # Reset the staging area for the next batch.
                batch = []
                print('{} images processed'.format(idx))

import csv

# Persist the extracted feature matrix as raw binary.
# FIX: tofile() writes binary bytes, so the file must be opened in binary
# mode ('wb'); the original opened it in text mode ('w'), which corrupts the
# data on Windows and can raise on Python 3.
with open('codes', 'wb') as f:
    codes.tofile(f)

# Persist one class label per line alongside the feature file.
# newline='' is required by the csv module so it controls line endings itself.
with open('labels', 'w', newline='') as f:
    # delimiter='\n' makes writerow emit the labels newline-separated.
    writer = csv.writer(f, delimiter='\n')
    writer.writerow(labels)
from sklearn.preprocessing import LabelBinarizer

# One-hot encode the string class labels, e.g. 'roses' -> [0, 0, 1, 0, 0].
lb = LabelBinarizer()
# fit_transform is documented as equivalent to fit(labels) then transform(labels).
labels_vecs = lb.fit_transform(labels)
from sklearn.model_selection import StratifiedShuffleSplit

# One stratified shuffle: 80% train, 20% held out (class ratios preserved).
# NOTE(review): no random_state is set, so the split differs between runs.
ss = StratifiedShuffleSplit(n_splits=1, test_size=0.2)
train_idx, val_idx = next(ss.split(codes, labels))

# Cut the 20% hold-out in half: one part validation, one part test.
half_val_len = len(val_idx) // 2
val_idx, test_idx = val_idx[:half_val_len], val_idx[half_val_len:]

# Materialise the three (features, one-hot labels) subsets.
train_x, train_y = codes[train_idx], labels_vecs[train_idx]
val_x, val_y = codes[val_idx], labels_vecs[val_idx]
test_x, test_y = codes[test_idx], labels_vecs[test_idx]

for split_name, xs, ys in (("Train", train_x, train_y),
                           ("Validation", val_x, val_y),
                           ("Test", test_x, test_y)):
    print("{} shapes (x, y):".format(split_name), xs.shape, ys.shape)
# Classifier head: one 256-unit hidden layer on top of the VGG16 features.
# FIX: the original defined the placeholder as `input_` but fed/used `inputs_`
# everywhere else (NameError); unified on `inputs_`.
inputs_ = tf.placeholder(tf.float32, shape=[None, codes.shape[1]])
labels_ = tf.placeholder(tf.int64, shape=[None, labels_vecs.shape[1]])

fc = tf.contrib.layers.fully_connected(inputs_, 256)
logits = tf.contrib.layers.fully_connected(fc, labels_vecs.shape[1],
                                           activation_fn=None)
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=labels_,
                                                        logits=logits)
cost = tf.reduce_mean(cross_entropy)
optimizer = tf.train.AdamOptimizer().minimize(cost)

# Accuracy: fraction of samples whose argmax prediction matches the label.
predicted = tf.nn.softmax(logits)
correct_pred = tf.equal(tf.argmax(predicted, 1), tf.argmax(labels_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))


def get_batches(x, y, n_batches=10):
    """Yield (x, y) mini-batches; the last batch absorbs the remainder.

    FIX: the original sliced y as y[ii:batch_size] (wrong window) and rebound
    the x/y parameters inside the loop, so every batch after the first was
    taken from an already-truncated array.
    """
    batch_size = len(x) // n_batches
    for ii in range(0, n_batches * batch_size, batch_size):
        if ii != (n_batches - 1) * batch_size:
            batch_x, batch_y = x[ii:ii + batch_size], y[ii:ii + batch_size]
        else:
            # Final batch: take everything that is left.
            batch_x, batch_y = x[ii:], y[ii:]
        yield batch_x, batch_y


epochs = 20
iteration = 0
saver = tf.train.Saver()

with tf.Session() as sess:  # FIX: original had the typo tf.Sessioon()
    sess.run(tf.global_variables_initializer())
    for e in range(epochs):
        for x, y in get_batches(train_x, train_y):
            feed = {inputs_: x,
                    labels_: y}
            loss, _ = sess.run([cost, optimizer], feed_dict=feed)
            # FIX: format spec was '{;.5f}', which raises ValueError.
            print('Epoch:{}/{}'.format(e + 1, epochs),
                  'Iteration:{}'.format(iteration),
                  'Training loss: {:.5f}'.format(loss))
            iteration += 1

            # Check validation accuracy every 5 iterations.
            if iteration % 5 == 0:
                feed = {inputs_: val_x,  # FIX: was input_/`;` syntax error
                        labels_: val_y}
                val_acc = sess.run(accuracy, feed_dict=feed)
                # FIX: '{;.4f}' -> '{:.4f}', and report e+1 to match the
                # 1-based epoch number used by the training log above.
                print('Epoch:{}/{}'.format(e + 1, epochs),
                      'Iteration:{}'.format(iteration),
                      'Validation Acc: {:.4f}'.format(val_acc))
    # Save inside the session so `sess` is still open.
    saver.save(sess, 'checkpoints/flowers.ckpt')

# Reload the latest checkpoint and measure accuracy on the untouched test set.
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))
    feed = {inputs_: test_x,
            labels_: test_y}
    test_acc = sess.run(accuracy, feed_dict=feed)

print('Test accuracy : {:.4f}'.format(test_acc))

转载地址:http://rpnhb.baihongyu.com/

你可能感兴趣的文章
ansible剧本如何写_我学过的3课:写Ansible剧本
查看>>
bash 脚本部署lmnp_使用Bash自动化Helm部署
查看>>
linux 中移动文件_如何在Linux中移动文件
查看>>
ansible 模块_您需要知道的10个Ansible模块
查看>>
无处不在_Kubernetes几乎无处不在,正在使用Java以及更多的行业趋势
查看>>
ansible 中文文档_浏览Ansible文档,自动执行补丁,虚拟化以及更多新闻
查看>>
人脸关键点 开源数据_谦虚是开源成功的关键,Kubernetes安全斗争以及更多行业趋势...
查看>>
markdown_Markdown初学者备忘单
查看>>
devops失败的原因_失败是无可指责的DevOps的功能
查看>>
开源项目演示_3种开源工具可让您的演示文稿流行
查看>>
rust编程语言_Mozilla的Rust编程语言处于关键阶段
查看>>
kicad阻焊层 设计_使用开源工具KiCad设计的footSHIELD
查看>>
开源项目如何本地更新_本地化开源项目的3个技巧
查看>>
唱吧录制的歌曲转换成mp3_录制开放文化歌曲
查看>>
Mercy Health为其主要门户网站设置了Drupal和Alfresco
查看>>
gpl2 gpl3区别_自由软件基金会将举办有关GPL执法和法律道德的研讨会
查看>>
python 下三角矩阵_Python | 矩阵的上三角
查看>>
Java StringBuffer CharSequence subSequence(int spos,int epos)方法与示例
查看>>
Java Collections unmodifiableList()方法与示例
查看>>
python 示例_Python日历类| itermonthdates()方法与示例
查看>>