Google has published its quantization method in this paper. It uses int8 for the forward pass but float32 for back-propagation, since back-propagation needs higher precision to accumulate gradients. A question came to me right after reading the paper: why were all the performance tests run on mobile phones (ARM architecture)? Inference with a model quantized by Google's method requires not only int8 additions and multiplications but also bit-shift operations. The AVX instruction set on Intel x86_64 can accelerate multiply-accumulate (MAC) operations, but it does not speed up bit shifts.
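To make the bit-shift point concrete, here is a minimal sketch (my own illustration, not code from the paper or from gemmlowp, and the real kernels use rounding/saturating variants): the real-valued scale M = S1*S2/S3 of each layer is folded into an int32 mantissa M0 and a right shift, so every quantized layer output needs an integer multiply followed by a bit shift. The helper names and the scale values below are made up for the example.

```python
def quantize_multiplier(m):
    """Split a real multiplier M in (0, 1) into an int32 mantissa M0 and a shift n."""
    n = 0
    while m < 0.5:
        m *= 2.0
        n += 1
    m0 = int(round(m * (1 << 31)))        # M ~= M0 * 2^-(31+n), with M0 in [2^30, 2^31)
    return m0, n

def requantize(acc_int32, m0, n):
    """Rescale an int32 accumulator back to the int8 output range."""
    return (acc_int32 * m0) >> (31 + n)   # integer multiply + bit shift

# Example: an int8 GEMM accumulates into int32, then the result is rescaled.
S1, S2, S3 = 0.02, 0.05, 0.1              # input, weight and output scales (made up)
m0, n = quantize_multiplier(S1 * S2 / S3)
print(requantize(12345, m0, n))           # ~ 12345 * 0.01 -> 123
```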
To verify my suspicion, I wrote a ResNet-50 model (float32) to classify the CIFAR-100 dataset. After training it for a few epochs, I evaluated the inference speed with my 'eval.py' (full listing in the appendix). The result is:
```
Time: 5.58819s
```
Then I followed these steps to add tf.contrib.quantize.create_training_graph() and tf.contrib.quantize.create_eval_graph() to my code (a sketch of where these calls go is shown after the timing result below). This time, the inference speed is:
```
Time: 6.23221s
```
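For reference, this is roughly where the two rewrite calls sit relative to the training and evaluation graphs. It is a simplified sketch rather than my exact code; the quant_delay value and the placeholder shapes are illustrative.

```python
import tensorflow as tf
import resnet_v2

# Training graph: insert fake-quantization ops after the forward pass is built
# and before the optimizer, so gradients are still accumulated in float32.
train_graph = tf.Graph()
with train_graph.as_default():
    images = tf.placeholder(tf.float32, [None, 32, 32, 3])
    labels = tf.placeholder(tf.int64, [None])
    logits, _ = resnet_v2.resnet_v2_50(images, 100)
    loss = tf.reduce_mean(
        tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits))

    tf.contrib.quantize.create_training_graph(input_graph=train_graph,
                                              quant_delay=2000)  # quant_delay is illustrative

    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

# Evaluation graph: rewrite for inference (folds batch norms and records the
# quantization ranges) before the graph is frozen to a .pb file.
eval_graph = tf.Graph()
with eval_graph.as_default():
    images = tf.placeholder(tf.float32, [None, 32, 32, 3])
    logits, _ = resnet_v2.resnet_v2_50(images, 100, is_training=False)
    tf.contrib.quantize.create_eval_graph(input_graph=eval_graph)
```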
A bit of a disappointment. The quantized (int8) version of the model does not speed up inference on an x86 CPU. Maybe we need to look for a more powerful quantization algorithm.
Appendix:
```python
# eval.py
from input_data import Cifar100Data
import tensorflow as tf
import numpy as np
import resnet_v2
import argparse
import time
import sys

EVAL_SAMPLES = 10000
BATCH_SIZE = 10000
MODEL_PATH = './models/'
MODEL_NAME = 'cifar_resnet_50'


def cnn_part(images):
    print(images.shape)
    ivg, _ = resnet_v2.resnet_v2_50(images, 100)
    return ivg


def main(_):
    with tf.device('/cpu:0'):
        images = tf.placeholder(tf.float32, [BATCH_SIZE, 32, 32, 3])
        labels = tf.placeholder(tf.int64, [BATCH_SIZE])

        with tf.contrib.slim.arg_scope(
                [tf.contrib.slim.conv2d],
                weights_initializer=tf.truncated_normal_initializer(mean=0, stddev=0.1)):
            image_vector = cnn_part(images)

        loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=image_vector)
        loss = tf.reduce_mean(loss)
        opt = tf.train.AdamOptimizer(1e-3)
        train_op = tf.contrib.slim.learning.create_train_op(loss, opt)

        correct_prediction = tf.equal(tf.argmax(image_vector, 1), labels)
        correct_prediction = tf.cast(correct_prediction, tf.float32)
        accuracy = tf.reduce_mean(correct_prediction)

        data = Cifar100Data('/disk3/cifar/cifar-100-python/test')
        saver = tf.train.Saver()

        with tf.Session() as sess:
            # Load the frozen (quantized) graph; 'rb' is required because
            # ParseFromString() expects raw bytes.
            with tf.gfile.FastGFile('./models/cifar_resnet_50_quant.pb', 'rb') as fl:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(fl.read())
                tf.import_graph_def(graph_def, name='')

            # Restore the weights from the chosen checkpoint.
            saver.restore(sess, MODEL_PATH + MODEL_NAME + '-' + str(FLAGS.epoch))

            batch = data.next_batch(BATCH_SIZE)
            # Run the evaluation three times and time each run.
            for i in range(3):
                begin = time.time()
                res = sess.run(accuracy, feed_dict={images: batch[0], labels: batch[1]})
                print("Time: %gs" % (time.time() - begin))
                print(res)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=str, default='8',
                        help='Epoch of checkpoint for evaluation')
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
```
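For completeness, here is a rough sketch of how the frozen 'cifar_resnet_50_quant.pb' that eval.py loads could be produced from a training checkpoint. This is not my exact export script; the checkpoint suffix ('-8') and the output node name are assumptions that depend on how the graph was actually named.

```python
import tensorflow as tf
from tensorflow.python.framework import graph_util
import resnet_v2

MODEL_PATH = './models/'
MODEL_NAME = 'cifar_resnet_50'

g = tf.Graph()
with g.as_default():
    images = tf.placeholder(tf.float32, [None, 32, 32, 3], name='images')
    logits, _ = resnet_v2.resnet_v2_50(images, 100, is_training=False)
    # Rewrite the inference graph with quantization ops before freezing.
    tf.contrib.quantize.create_eval_graph(input_graph=g)
    saver = tf.train.Saver()

with tf.Session(graph=g) as sess:
    saver.restore(sess, MODEL_PATH + MODEL_NAME + '-8')  # checkpoint suffix is an assumption
    frozen = graph_util.convert_variables_to_constants(
        sess, g.as_graph_def(),
        ['resnet_v2_50/logits/BiasAdd'])                 # output node name is an assumption
    with tf.gfile.GFile(MODEL_PATH + MODEL_NAME + '_quant.pb', 'wb') as f:
        f.write(frozen.SerializeToString())
```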