生成SparseTensor需要一段时间
#dense是一个n x m矩阵
sparse = coo_matrix(密集)#几乎是瞬间的
#易读性sparse_indicies = list(zip( sparse.row.astype(…
您可以尝试制作TensorFlow常量并将它们存储在一个 GraphDef 文件,然后加载它并在需要时将它们导入图表。我不知道这是否会比你现在的方法更快。
GraphDef
要将常量导出到文件,您可以执行以下操作:
import tensorflow as tf # In an independent graph to make sure only the data we want is stored with tf.Graph().as_default(): sparse = coo_matrix(dense) sparse_indicies = list(zip( sparse.row.astype(np.int64).tolist(), sparse.col.astype(np.int64).tolist() )) type_casted = (sparse.data).astype(np.float32) # Make TensorFlow constants indices = tf.constant(sparse_indicies, name='Indices', dtype=tf.int64) values = tf.constant(type_casted, name='Values') shape = tf.constant(sparse.shape, dtype=tf.int64, name='Shape') # Serialize graph graph_def = tf.get_default_graph().as_graph_def() with open('sparse_tensor_data.pb', 'wb') as f: f.write(graph_def.SerializeToString())
您可以从其他地方导入它,如下所示:
import tensorflow as tf # Read graph graph_def = tf.GraphDef() with open('sparse_tensor_data.pb', 'rb') as f: graph_def.MergeFromString(f.read()) # Import tensors indices, values, shape = tf.import_graph_def( graph_def, return_elements=['Indices:0', 'Values:0', 'Shape:0'], name='SparseTensorImport') del graph_def # Create sparse tensor input_tensor = tf.SparseTensor(indices=indices, values=values, dense_shape=shape)