# 本文共 4880 字,大约阅读时间需要 16 分钟。
import os

import numpy as np
import tensorflow as tf

from tensorflow_vgg import vgg16
from tensorflow_vgg import utils
# Root directory of the dataset; the entries found here are scanned for
# class sub-directories.
data_dir = 'flower_photos/'
contents = os.listdir(data_dir)
# Keep only the entries that are directories (os.path.isdir tells whether a
# path is a directory); plain files lying next to them are ignored.
classes = [
    each for each in contents
    if os.path.isdir(os.path.join(data_dir, each))
]
# Number of images pushed through VGG16 per run; the more memory the
# platform has, the higher this value can be set.
batch_size = 10
# Feature arrays of every processed batch, joined into `codes` at the end.
codes_list = []
# Class name of every processed image, in processing order.
labels = []
# Temporary buffer holding the images of the batch being assembled.
batch = []
# Stays None when no image at all was processed.
codes = None

with tf.Session() as sess:
    # Build the VGG16 model object.
    vgg = vgg16.Vgg16()
    # None leaves the batch dimension unspecified.
    input_ = tf.placeholder(tf.float32, [None, 224, 224, 3])
    # The name scope mainly makes the parameter naming easier to manage.
    with tf.name_scope("content_vgg"):
        # Load the VGG16 model.
        vgg.build(input_)

    # Compute the VGG16 feature values for each class separately.
    for each in classes:
        print("Starting {} images".format(each))
        class_path = os.path.join(data_dir, each)
        files = os.listdir(class_path)
        for ii, file in enumerate(files, 1):
            # Load the image and add it to the current batch.
            img = utils.load_image(os.path.join(class_path, file))
            batch.append(img.reshape((1, 224, 224, 3)))
            labels.append(each)

            # Run the network once the batch is full, or when the last
            # image of this class has been reached (partial final batch).
            if ii % batch_size == 0 or ii == len(files):
                images = np.concatenate(batch)
                # feed_dict supplies the value for the placeholder tensor.
                feed_dict = {input_: images}
                # Compute the feature values for this batch.
                codes_batch = sess.run(vgg.relu6, feed_dict=feed_dict)
                # Collect the batch; joining once at the end avoids
                # re-copying all earlier results on every batch.
                codes_list.append(codes_batch)
                # Empty the buffer for the next batch.
                batch = []
                print('{} images processed'.format(ii))

# Join all batch results into a single array, one row per image.
if codes_list:
    codes = np.concatenate(codes_list)
import csv

# ndarray.tofile writes the array data in raw binary form, so the target
# file has to be opened in binary mode.
with open('codes', 'wb') as f:
    codes.tofile(f)

# By default csv reads and writes with a comma delimiter and a double-quote
# quotechar; here the delimiter is a newline, so the single row written below
# puts every label on its own line.
# newline='' lets the csv module control line endings itself, as its
# documentation recommends for files handed to csv.writer.
with open('labels', 'w', newline='') as f:
    writer = csv.writer(f, delimiter='\n')
    writer.writerow(labels)
from sklearn.preprocessing import LabelBinarizer

# Turn the class-name labels into indicator vectors.
lb = LabelBinarizer()
# NOTE(review): fit followed by transform on the same data is presumably
# equivalent to lb.fit_transform(labels) - confirm before collapsing them.
lb.fit(labels)
labels_vecs = lb.transform(labels)
from sklearn.model_selection import StratifiedShuffleSplit

# One split, with 20% of the samples held out.
ss = StratifiedShuffleSplit(n_splits=1, test_size=0.2)
# Index arrays for the 80% training part and the 20% held-out part.
train_idx, val_idx = next(ss.split(codes, labels))

# Cut the held-out 20% once more, 1:1, into validation and test indices.
half_val_len = len(val_idx) // 2
val_idx, test_idx = val_idx[:half_val_len], val_idx[half_val_len:]

train_x, train_y = codes[train_idx], labels_vecs[train_idx]
val_x, val_y = codes[val_idx], labels_vecs[val_idx]
test_x, test_y = codes[test_idx], labels_vecs[test_idx]

print("Train shapes (x, y):", train_x.shape, train_y.shape)
print("Validation shapes (x, y):", val_x.shape, val_y.shape)
print("Test shapes (x, y):", test_x.shape, test_y.shape)
# Classifier graph: one 256-unit hidden layer on top of the pre-computed
# codes, followed by a linear layer with one logit per label column.
# The placeholder is named `inputs_` throughout; the original defined it as
# `input_` but referenced `inputs_`, which raised a NameError.
inputs_ = tf.placeholder(tf.float32, shape=[None, codes.shape[1]])
labels_ = tf.placeholder(tf.int64, shape=[None, labels_vecs.shape[1]])

fc = tf.contrib.layers.fully_connected(inputs_, 256)
logits = tf.contrib.layers.fully_connected(
    fc, labels_vecs.shape[1], activation_fn=None)

cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
    labels=labels_, logits=logits)
cost = tf.reduce_mean(cross_entropy)
optimizer = tf.train.AdamOptimizer().minimize(cost)

# Accuracy: fraction of rows whose arg-max prediction matches the arg-max
# of the label vector.
predicted = tf.nn.softmax(logits)
correct_pred = tf.equal(tf.argmax(predicted, 1), tf.argmax(labels_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))


def get_batches(x, y, n_batches=10):
    """Yield (x, y) slices that split the data into ``n_batches`` batches.

    Every batch holds ``len(x) // n_batches`` rows, except the last one,
    which also takes whatever rows remain. When there are fewer rows than
    ``n_batches``, all rows are yielded as one batch (nothing is yielded
    for empty input).

    Args:
        x: Sliceable sequence of inputs.
        y: Sliceable sequence of targets, aligned with ``x``.
        n_batches: Number of batches to produce.

    Yields:
        Tuples ``(x_batch, y_batch)`` covering the data in order.
    """
    batch_size = len(x) // n_batches
    if batch_size == 0:
        # Too few rows to fill n_batches; a zero step would make range() fail.
        if len(x):
            yield x, y
        return

    last_start = (n_batches - 1) * batch_size
    for ii in range(0, n_batches * batch_size, batch_size):
        # Slice into fresh names: rebinding x and y here would make every
        # later batch a slice of the previous batch instead of the data.
        if ii != last_start:
            batch_x, batch_y = x[ii:ii + batch_size], y[ii:ii + batch_size]
        else:
            # The last batch absorbs the remainder.
            batch_x, batch_y = x[ii:], y[ii:]
        yield batch_x, batch_y


epochs = 20
iteration = 0
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for e in range(epochs):
        for x, y in get_batches(train_x, train_y):
            feed = {inputs_: x, labels_: y}
            loss, _ = sess.run([cost, optimizer], feed_dict=feed)
            print('Epoch:{}/{}'.format(e + 1, epochs),
                  'Iteration:{}'.format(iteration),
                  'Training loss: {:.5f}'.format(loss))
            iteration += 1

            # Report validation accuracy every 5 iterations.
            if iteration % 5 == 0:
                feed = {inputs_: val_x, labels_: val_y}
                val_acc = sess.run(accuracy, feed_dict=feed)
                print('Epoch:{}/{}'.format(e + 1, epochs),
                      'Iteration:{}'.format(iteration),
                      'Validation Acc: {:.4f}'.format(val_acc))

    # Save while the session is still open; make sure the target directory
    # exists before the checkpoint is written.
    os.makedirs('checkpoints', exist_ok=True)
    saver.save(sess, 'checkpoints/flowers.ckpt')

# Reload the latest checkpoint in a fresh session and score the test split.
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))
    feed = {inputs_: test_x, labels_: test_y}
    test_acc = sess.run(accuracy, feed_dict=feed)
    print('Test accuracy : {:.4f}'.format(test_acc))
# 转载地址:http://rpnhb.baihongyu.com/