import datetime
import logging
import os

import numpy as np
import tensorflow as tf
from scipy import ndimage  # reads image files into numpy ndarrays
from tensorflow.contrib import rnn
# SavedModel builder: exports graph + variables + signatures for TF Serving.
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import (signature_constants,
                                           signature_def_utils, tag_constants,
                                           utils)
from tensorflow.python.util import compat

logging.basicConfig(
    format='%(asctime)s %(levelname)-8s %(message)s',
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S')

# Command-line flags:
# tf.app.flags.DEFINE_string("param_name", "default_val", "description")
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
flags.DEFINE_integer('epoch_number', 10, 'Number of epochs to run trainer.')
flags.DEFINE_integer("batch_size", 1024,
                     "indicates batch size in a single gpu, default is 1024")
flags.DEFINE_string("checkpoint_dir", "./checkpoint/",
                    "indicates the checkpoint dirctory")
flags.DEFINE_string("tensorboard_dir", "./tensorboard/",
                    "indicates training output")
flags.DEFINE_string("optimizer", "adam", "optimizer to train")
flags.DEFINE_integer('steps_to_validate', 1,
                     'Steps to validate and print loss')
flags.DEFINE_string("mode", "train",
                    "Option mode: train, inference, savedmodel")
flags.DEFINE_string("image", "./data/inference/Pikachu.png",
                    "The image to inference")
flags.DEFINE_string("checkpoint_path", "./checkpoint/", "Path for checkpoint")
flags.DEFINE_string(
    "model", "cnn",
    "Model to train, option model: cnn, lstm, bidirectional_lstm, stacked_lstm"
)
flags.DEFINE_string("model_path", "./model/", "Path of the model")
flags.DEFINE_integer("model_version", 1, "Version of the model")


def main():
  """Train, run inference with, or export the Pokemon type classifier.

  The behavior is selected with --mode:
    train      -- fit the model chosen by --model on ./data/train, checkpoint
                  every validation step, and export a SavedModel when done
    inference  -- restore the latest checkpoint and classify --image
    savedmodel -- restore the latest checkpoint and only export a SavedModel
  """
  print("Start Pokemon classifier")

  if not os.path.exists(FLAGS.checkpoint_path):
    os.makedirs(FLAGS.checkpoint_path)
  CHECKPOINT_FILE = FLAGS.checkpoint_path + "/checkpoint.ckpt"
  LATEST_CHECKPOINT = tf.train.latest_checkpoint(FLAGS.checkpoint_path)

  # Initialize train and test data.
  # NOTE(review): the dataset sizes are hard-coded and must match the number
  # of .png files actually present under the data directories.
  TRAIN_IMAGE_NUMBER = 646
  TEST_IMAGE_NUMBER = 68
  IMAGE_SIZE = 32
  RGB_CHANNEL_SIZE = 3
  LABEL_SIZE = 17

  train_dataset = np.ndarray(
      shape=(TRAIN_IMAGE_NUMBER, IMAGE_SIZE, IMAGE_SIZE,
             RGB_CHANNEL_SIZE),  # channels-last layout
      dtype=np.float32)
  test_dataset = np.ndarray(
      shape=(TEST_IMAGE_NUMBER, IMAGE_SIZE, IMAGE_SIZE, RGB_CHANNEL_SIZE),
      dtype=np.float32)
  train_labels = np.ndarray(shape=(TRAIN_IMAGE_NUMBER, ), dtype=np.int32)
  test_labels = np.ndarray(shape=(TEST_IMAGE_NUMBER, ), dtype=np.int32)

  TRAIN_DATA_DIR = "./data/train/"
  TEST_DATA_DIR = "./data/test/"
  VALIDATE_DATA_DIR = "./data/validate/"
  IMAGE_FORMAT = ".png"
  index = 0  # counts images loaded so far
  # Maps a Pokemon type name (== sub-directory name) to its integer label.
  pokemon_type_id_map = {
      "Bug": 0,
      "Dark": 1,
      "Dragon": 2,
      "Electric": 3,
      "Fairy": 4,
      "Fighting": 5,
      "Fire": 6,
      "Ghost": 7,
      "Grass": 8,
      "Ground": 9,
      "Ice": 10,
      "Normal": 11,
      "Poison": 12,
      "Psychic": 13,
      "Rock": 14,
      "Steel": 15,
      "Water": 16
  }
  # Inverse lookup: label index -> type name, used to print predictions.
  pokemon_types = [
      "Bug", "Dark", "Dragon", "Electric", "Fairy", "Fighting", "Fire",
      "Ghost", "Grass", "Ground", "Ice", "Normal", "Poison", "Psychic",
      "Rock", "Steel", "Water"
  ]

  # Step 1: load the training data (one sub-directory per Pokemon type).
  for pokemon_type in os.listdir(TRAIN_DATA_DIR):
    for image_filename in os.listdir(
        os.path.join(TRAIN_DATA_DIR, pokemon_type)):
      if image_filename.endswith(IMAGE_FORMAT):
        image_filepath = os.path.join(TRAIN_DATA_DIR, pokemon_type,
                                      image_filename)
        image_ndarray = ndimage.imread(image_filepath, mode="RGB")
        train_dataset[index] = image_ndarray
        # Convert the directory name into its numeric label.
        train_labels[index] = pokemon_type_id_map.get(pokemon_type)
        index += 1

  index = 0

  # Step 2: load the test data with the same layout.
  for pokemon_type in os.listdir(TEST_DATA_DIR):
    for image_filename in os.listdir(
        os.path.join(TEST_DATA_DIR, pokemon_type)):
      if image_filename.endswith(IMAGE_FORMAT):
        image_filepath = os.path.join(TEST_DATA_DIR, pokemon_type,
                                      image_filename)
        image_ndarray = ndimage.imread(image_filepath, mode="RGB")
        test_dataset[index] = image_ndarray
        test_labels[index] = pokemon_type_id_map.get(pokemon_type)
        index += 1

  # Step 3: define the model.
  # Placeholders. `keys_placeholder`/`keys` are unused below; they appear to
  # be left over from an older exporter signature.
  keys_placeholder = tf.placeholder(tf.int32, shape=[None, 1])
  keys = tf.identity(keys_placeholder)

  # Serving input: a batch of base64-encoded images.
  model_base64_placeholder = tf.placeholder(
      shape=[None], dtype=tf.string, name="model_input_b64_images")
  model_base64_string = tf.decode_base64(model_base64_placeholder)
  # Decode and resize each image in the batch (tf.map_fn acts like map()).
  # NOTE(review): decode_jpeg is used although the training data is PNG;
  # confirm the TF version in use decodes PNG here, or switch to decode_image.
  model_base64_input = tf.map_fn(
      lambda img: tf.image.resize_images(
          tf.image.decode_jpeg(img, channels=RGB_CHANNEL_SIZE),
          [IMAGE_SIZE, IMAGE_SIZE]),
      model_base64_string,
      dtype=tf.float32)

  x = tf.placeholder(
      tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE, RGB_CHANNEL_SIZE))
  y = tf.placeholder(tf.int32, shape=(None, ))

  batch_size = FLAGS.batch_size
  epoch_number = FLAGS.epoch_number
  checkpoint_dir = FLAGS.checkpoint_dir
  if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
  tensorboard_dir = FLAGS.tensorboard_dir
  mode = FLAGS.mode
  checkpoint_file = checkpoint_dir + "/checkpoint.ckpt"
  steps_to_validate = FLAGS.steps_to_validate

  def cnn_inference(x):
    """Two conv/pool layers plus a fully-connected layer.

    Args:
      x: float32 tensor of shape [batch, 32, 32, 3].
    Returns:
      Logits tensor of shape [batch, LABEL_SIZE].
    """
    # Convolution layer result: [BATCH_SIZE, 16, 16, 32]
    # Output size: (n + 2p - f) / s + 1
    with tf.variable_scope("conv1"):
      weights = tf.get_variable(
          "weights", [3, 3, 3, 32],
          initializer=tf.random_normal_initializer())
      bias = tf.get_variable(
          "bias", [32], initializer=tf.random_normal_initializer())

      layer = tf.nn.conv2d(
          x, weights, strides=[1, 1, 1, 1], padding="SAME")  # 32*32*32
      layer = tf.nn.bias_add(layer, bias)
      layer = tf.nn.relu(layer)
      # Pooling output size: (n - f) / s + 1
      layer = tf.nn.max_pool(
          layer, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
          padding="SAME")  # 16*16*32

    # Convolution layer result: [BATCH_SIZE, 8, 8, 64]
    with tf.variable_scope("conv2"):
      weights = tf.get_variable(
          "weights", [3, 3, 32, 64],
          initializer=tf.random_normal_initializer())
      bias = tf.get_variable(
          "bias", [64], initializer=tf.random_normal_initializer())

      layer = tf.nn.conv2d(
          layer, weights, strides=[1, 1, 1, 1], padding="SAME")  # 16*16*64
      layer = tf.nn.bias_add(layer, bias)
      layer = tf.nn.relu(layer)
      layer = tf.nn.max_pool(
          layer, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
          padding="SAME")  # 8*8*64

    # Flatten for the fully-connected layer.
    layer = tf.reshape(layer, [-1, 8 * 8 * 64])

    # Fully-connected layer result: [BATCH_SIZE, 17]
    with tf.variable_scope("fc1"):
      weights = tf.get_variable(
          "weights", [8 * 8 * 64, LABEL_SIZE],
          initializer=tf.random_normal_initializer())
      bias = tf.get_variable(
          "bias", [LABEL_SIZE], initializer=tf.random_normal_initializer())
      layer = tf.add(tf.matmul(layer, weights), bias)

    return layer  # one logit per Pokemon type

  def lstm_inference(x):
    """Single-layer LSTM over image rows; returns [batch, LABEL_SIZE] logits."""
    RNN_HIDDEN_UNITS = 128

    # x was [BATCH_SIZE, 32, 32, 3]
    # x changes to [32, BATCH_SIZE, 32, 3]
    x = tf.transpose(x, [1, 0, 2, 3])
    # x changes to [32 * BATCH_SIZE, 32 * 3]
    x = tf.reshape(x, [-1, IMAGE_SIZE * RGB_CHANNEL_SIZE])
    # x changes to array of 32 * [BATCH_SIZE, 32 * 3]
    x = tf.split(axis=0, num_or_size_splits=IMAGE_SIZE, value=x)

    weights = tf.Variable(tf.random_normal([RNN_HIDDEN_UNITS, LABEL_SIZE]))
    biases = tf.Variable(tf.random_normal([LABEL_SIZE]))

    # output size is 128, state size is (c=128, h=128)
    lstm_cell = rnn.BasicLSTMCell(RNN_HIDDEN_UNITS, forget_bias=1.0)
    # Fix: tf.contrib.rnn exposes static_rnn, not rnn (which raised
    # AttributeError). outputs is array of 32 * [BATCH_SIZE, 128].
    outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)

    # outputs[-1] is [BATCH_SIZE, 128]
    return tf.matmul(outputs[-1], weights) + biases

  def bidirectional_lstm_inference(x):
    """Bidirectional LSTM over image rows; returns [batch, LABEL_SIZE] logits."""
    RNN_HIDDEN_UNITS = 128

    # x was [BATCH_SIZE, 32, 32, 3]
    # x changes to [32, BATCH_SIZE, 32, 3]
    x = tf.transpose(x, [1, 0, 2, 3])
    # x changes to [32 * BATCH_SIZE, 32 * 3]
    x = tf.reshape(x, [-1, IMAGE_SIZE * RGB_CHANNEL_SIZE])
    # x changes to array of 32 * [BATCH_SIZE, 32 * 3]
    x = tf.split(axis=0, num_or_size_splits=IMAGE_SIZE, value=x)

    # Forward and backward outputs are concatenated, hence 2 * hidden units.
    weights = tf.Variable(
        tf.random_normal([2 * RNN_HIDDEN_UNITS, LABEL_SIZE]))
    biases = tf.Variable(tf.random_normal([LABEL_SIZE]))

    # output size is 128, state size is (c=128, h=128)
    fw_lstm_cell = rnn.BasicLSTMCell(RNN_HIDDEN_UNITS, forget_bias=1.0)
    bw_lstm_cell = rnn.BasicLSTMCell(RNN_HIDDEN_UNITS, forget_bias=1.0)
    # Fix: the contrib name is static_bidirectional_rnn, not
    # bidirectional_rnn. outputs is array of 32 * [BATCH_SIZE, 256].
    outputs, _, _ = rnn.static_bidirectional_rnn(
        fw_lstm_cell, bw_lstm_cell, x, dtype=tf.float32)

    # outputs[-1] is [BATCH_SIZE, 256]
    return tf.matmul(outputs[-1], weights) + biases

  def stacked_lstm_inference(x):
    """Two stacked LSTM layers over image rows; returns [batch, 17] logits."""
    RNN_HIDDEN_UNITS = 128

    # x was [BATCH_SIZE, 32, 32, 3]
    # x changes to [32, BATCH_SIZE, 32, 3]
    x = tf.transpose(x, [1, 0, 2, 3])
    # x changes to [32 * BATCH_SIZE, 32 * 3]
    x = tf.reshape(x, [-1, IMAGE_SIZE * RGB_CHANNEL_SIZE])
    # x changes to array of 32 * [BATCH_SIZE, 32 * 3]
    x = tf.split(axis=0, num_or_size_splits=IMAGE_SIZE, value=x)

    weights = tf.Variable(tf.random_normal([RNN_HIDDEN_UNITS, LABEL_SIZE]))
    biases = tf.Variable(tf.random_normal([LABEL_SIZE]))

    # Fix: reusing a single cell object for every layer
    # (rnn.MultiRNNCell([lstm_cell] * 2)) shares variables between layers and
    # raises a variable-reuse error on TF >= 1.1; build one cell per layer.
    lstm_cells = rnn.MultiRNNCell([
        rnn.BasicLSTMCell(RNN_HIDDEN_UNITS, forget_bias=1.0)
        for _ in range(2)
    ])

    # Fix: static_rnn, not rnn. outputs is array of 32 * [BATCH_SIZE, 128].
    outputs, states = rnn.static_rnn(lstm_cells, x, dtype=tf.float32)

    # outputs[-1] is [BATCH_SIZE, 128]
    return tf.matmul(outputs[-1], weights) + biases

  def inference(inputs):
    """Dispatch to the model selected by --model; exits on unknown names."""
    print("Use the model: {}".format(FLAGS.model))
    if FLAGS.model == "cnn":
      return cnn_inference(inputs)
    elif FLAGS.model == "lstm":
      return lstm_inference(inputs)
    elif FLAGS.model == "bidirectional_lstm":
      return bidirectional_lstm_inference(inputs)
    elif FLAGS.model == "stacked_lstm":
      return stacked_lstm_inference(inputs)
    else:
      print("Unknow model, exit now")
      exit(1)

  # Define the training ops.
  logit = inference(x)
  loss = tf.reduce_mean(
      tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logit, labels=y))

  learning_rate = FLAGS.learning_rate
  print("Use the optimizer: {}".format(FLAGS.optimizer))
  if FLAGS.optimizer == "sgd":
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  elif FLAGS.optimizer == "adadelta":
    optimizer = tf.train.AdadeltaOptimizer(learning_rate)
  elif FLAGS.optimizer == "adagrad":
    optimizer = tf.train.AdagradOptimizer(learning_rate)
  elif FLAGS.optimizer == "adam":
    optimizer = tf.train.AdamOptimizer(learning_rate)
  elif FLAGS.optimizer == "ftrl":
    optimizer = tf.train.FtrlOptimizer(learning_rate)
  elif FLAGS.optimizer == "rmsprop":
    optimizer = tf.train.RMSPropOptimizer(learning_rate)
  else:
    print("Unknow optimizer: {}, exit now".format(FLAGS.optimizer))
    exit(1)

  global_step = tf.Variable(0, name='global_step', trainable=False)
  train_op = optimizer.minimize(loss, global_step=global_step)

  # Prediction ops fed from the raw-image placeholder `x`, used by the
  # "inference" mode below. (Fix: the original referenced an undefined
  # `predict_op` there, which raised NameError at runtime.)
  predict_softmax = tf.nn.softmax(logit)
  predict_op = tf.argmax(predict_softmax, 1)

  # Reuse the same model variables for the base64/serving input path.
  tf.get_variable_scope().reuse_variables()
  inference_logits = inference(model_base64_input)
  inference_predict_softmax = tf.nn.softmax(inference_logits)
  inference_predict_op = tf.argmax(inference_predict_softmax, 1)
  inference_correct_prediction = tf.equal(inference_predict_op,
                                          tf.to_int64(y))
  inference_accuracy_op = tf.reduce_mean(
      tf.cast(inference_correct_prediction, tf.float32))

  # Serving signature: base64 images in, softmax + argmax prediction out.
  model_signature = signature_def_utils.build_signature_def(
      inputs={"images": utils.build_tensor_info(model_base64_placeholder)},
      outputs={
          "softmax": utils.build_tensor_info(inference_predict_softmax),
          "prediction": utils.build_tensor_info(inference_predict_op)
      },
      method_name=signature_constants.PREDICT_METHOD_NAME)

  # Checkpoint-style model saving (variables only, for resuming training).
  saver = tf.train.Saver()
  tf.summary.scalar('loss', loss)
  init_op = tf.global_variables_initializer()

  # Create a session to run the graph.
  with tf.Session() as sess:
    summary_op = tf.summary.merge_all()
    writer = tf.summary.FileWriter(tensorboard_dir, sess.graph)
    sess.run(init_op)

    if mode == "train":
      # Resume from an existing checkpoint if one is present.
      ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
      if ckpt and ckpt.model_checkpoint_path:
        logging.info("Continue training from the model {}".format(
            ckpt.model_checkpoint_path))
        saver.restore(sess, ckpt.model_checkpoint_path)

      # NOTE(review): each epoch feeds the whole training set at once;
      # FLAGS.batch_size is not used for mini-batching here.
      for epoch in range(epoch_number):
        _, loss_value, step = sess.run(
            [train_op, loss, global_step],
            feed_dict={x: train_dataset,
                       y: train_labels})

        if epoch % steps_to_validate == 0:
          logging.info("Epoch: {}, loss: {}".format(epoch, loss_value))
          saver.save(sess, checkpoint_file, global_step=step)

      # Export the trained model as a SavedModel for serving.
      export_path = os.path.join(
          # compat.as_bytes converts str/unicode to bytes.
          compat.as_bytes(FLAGS.model_path),
          compat.as_bytes(str(FLAGS.model_version)))
      logging.info("Export the model to {}".format(export_path))

      try:
        # tf.group bundles ops; tables_initializer is a no-op without tables.
        legacy_init_op = tf.group(
            tf.tables_initializer(), name='legacy_init_op')

        # SavedModel export: stores variables, graph structure, and the
        # input/output signature so serving systems can call the model.
        builder = saved_model_builder.SavedModelBuilder(export_path)
        builder.add_meta_graph_and_variables(
            sess, [tag_constants.SERVING],
            clear_devices=True,
            signature_def_map={
                signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                model_signature,
            },
            legacy_init_op=legacy_init_op)
        builder.save()
      except Exception as e:
        logging.error("Fail to export saved model, exception: {}".format(e))

    elif mode == "inference":
      ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
      if ckpt and ckpt.model_checkpoint_path:
        logging.info("Load the model {}".format(ckpt.model_checkpoint_path))
        saver.restore(sess, ckpt.model_checkpoint_path)

        start_time = datetime.datetime.now()

        image_ndarray = ndimage.imread(FLAGS.image, mode="RGB")
        # TODO: Update for server without gui
        #print_image(image_ndarray)
        image_ndarray = image_ndarray.reshape(1, IMAGE_SIZE, IMAGE_SIZE,
                                              RGB_CHANNEL_SIZE)

        prediction = sess.run(predict_op, feed_dict={x: image_ndarray})

        end_time = datetime.datetime.now()
        pokemon_type = pokemon_types[prediction[0]]
        logging.info("[{}] Predict type: {}".format(end_time - start_time,
                                                    pokemon_type))

    elif FLAGS.mode == "savedmodel":
      if not restore_from_checkpoint(sess, saver, LATEST_CHECKPOINT):
        logging.error("No checkpoint for exporting model, exit now")
        exit(1)

      export_path = os.path.join(
          compat.as_bytes(FLAGS.model_path),
          compat.as_bytes(str(FLAGS.model_version)))
      logging.info("Export the model to {}".format(export_path))

      try:
        legacy_init_op = tf.group(
            tf.tables_initializer(), name='legacy_init_op')

        builder = saved_model_builder.SavedModelBuilder(export_path)
        builder.add_meta_graph_and_variables(
            sess, [tag_constants.SERVING],
            clear_devices=True,
            signature_def_map={
                signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                model_signature,
            },
            legacy_init_op=legacy_init_op)
        builder.save()
      except Exception as e:
        logging.error("Fail to export saved model, exception: {}".format(e))


def print_image(image_ndarray):
  """Display an image ndarray with matplotlib (requires a GUI backend)."""
  import matplotlib.pyplot as plt
  plt.imshow(image_ndarray)
  plt.show()


def restore_from_checkpoint(sess, saver, checkpoint):
  """Restore session variables from `checkpoint` if it exists.

  Args:
    sess: the active tf.Session.
    saver: a tf.train.Saver covering the variables to restore.
    checkpoint: checkpoint path (or None) from tf.train.latest_checkpoint.
  Returns:
    True if the checkpoint was restored, False otherwise.
  """
  if checkpoint:
    logging.info("Restore session from checkpoint: {}".format(checkpoint))
    saver.restore(sess, checkpoint)
    return True
  else:
    logging.warning("Checkpoint not found: {}".format(checkpoint))
    return False


if __name__ == "__main__":
  main()