MiniFlow: a minimal implementation of TensorFlow

Reader submission, 2022-11-04 16:49:04


MiniFlow

Introduction

MiniFlow is a numerical computation library that implements TensorFlow's APIs.

- Supports mathematical calculations and composite operations
- Supports automatic partial derivatives and the chain rule
- Supports operations in C++ and Python backends via SWIG
- Supports Linux, macOS, Windows and Raspbian
- Supports both imperative and declarative computation
- Provides APIs compatible with TensorFlow
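To illustrate what "automatic partial derivatives and the chain rule" involve, here is a minimal reverse-mode sketch in plain Python. It is not MiniFlow's implementation; the `Node` class and `backward` function are hypothetical names chosen for this example, and only scalar addition and multiplication are covered.

```python
class Node:
    """A scalar value in a computation graph (hypothetical, for illustration)."""

    def __init__(self, value, parents=()):
        self.value = value
        # Each parent is stored with the local partial derivative d(self)/d(parent).
        self.parents = parents
        self.grad = 0.0

    def __add__(self, other):
        return Node(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Node(self.value * other.value,
                    ((self, other.value), (other, self.value)))


def backward(output):
    """Propagate gradients from output to every ancestor using the chain rule."""
    # Post-order DFS puts every node after its inputs; reversing it processes
    # each node only after all nodes that consume it.
    order, visited = [], set()

    def visit(node):
        if id(node) in visited:
            return
        visited.add(id(node))
        for parent, _ in node.parents:
            visit(parent)
        order.append(node)

    visit(output)
    output.grad = 1.0
    for node in reversed(order):
        for parent, local_grad in node.parents:
            # Chain rule: dOut/dParent += dOut/dNode * dNode/dParent
            parent.grad += node.grad * local_grad


# f(w, x, b) = w * x + b evaluated at w=3, x=2, b=1
w, x, b = Node(3.0), Node(2.0), Node(1.0)
f = w * x + b
backward(f)
print(f.value, w.grad, x.grad, b.grad)  # 7.0 2.0 3.0 1.0
```

Gradients are accumulated with `+=` so that a value used more than once (for example `w * w`) receives the sum of the contributions from each use.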

Installation

Install with pip.

pip install miniflow

Or run it with Docker.

docker run -it tobegit3hub/miniflow bash

Usage

MiniFlow's APIs are compatible with TensorFlow's; refer to the examples for more usage.

Basic operations

Run with TensorFlow.

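```python
# TensorFlow 1.x graph-mode API
import tensorflow as tf

sess = tf.Session()

hello = tf.constant("Hello, TensorFlow!")
sess.run(hello)
# b'Hello, TensorFlow!' (string tensors come back as bytes under Python 3)

a = tf.constant(10)
b = tf.constant(32)
sess.run(a + b)
# 42
```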

Run with MiniFlow.

```python
import miniflow as tf

sess = tf.Session()

hello = tf.constant("Hello, MiniFlow!")
sess.run(hello)
# "Hello, MiniFlow!"

a = tf.constant(10)
b = tf.constant(32)
sess.run(a + b)
# 42
```
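The declarative style used in both examples (first build a graph of constants and operations, then evaluate it with `Session.run`) can be sketched in a few lines of plain Python. This is a hypothetical stand-in that only mirrors the names `constant` and `Session`; it is not MiniFlow's actual code.

```python
class Tensor:
    """A graph node whose value is computed only when a session runs it (hypothetical)."""

    def __init__(self, compute, inputs=()):
        self.compute = compute  # function applied to the evaluated input values
        self.inputs = inputs

    def __add__(self, other):
        # Building a + b only records the operation; nothing is computed yet.
        return Tensor(lambda x, y: x + y, (self, other))


def constant(value):
    return Tensor(lambda: value)


class Session:
    def run(self, tensor):
        # Recursively evaluate the inputs, then apply this node's operation.
        input_values = [self.run(t) for t in tensor.inputs]
        return tensor.compute(*input_values)


sess = Session()
hello = constant("Hello, MiniFlow!")
print(sess.run(hello))  # Hello, MiniFlow!

a = constant(10)
b = constant(32)
print(sess.run(a + b))  # 42
```

Separating graph construction from execution is what lets a library analyze the whole graph, for example to derive gradients, before any value is computed.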

Use placeholders

Run with TensorFlow.

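```python
import tensorflow as tf

sess = tf.Session()

a = tf.placeholder(tf.float32)
b = tf.constant(32.0)

# Feed the placeholder by the tensor object...
sess.run(a + b, feed_dict={a: 10})
# 42.0

# ...or by the tensor's name
sess.run(a + b, feed_dict={a.name: 10})
# 42.0
```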

Run with MiniFlow.

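```python
import miniflow as tf

sess = tf.Session()

a = tf.placeholder(tf.float32)
b = tf.constant(32.0)

# Feed the placeholder by the tensor object...
sess.run(a + b, feed_dict={a: 10})
# 42.0

# ...or by the tensor's name
sess.run(a + b, feed_dict={a.name: 10})
# 42.0
```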
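A placeholder is a graph node with no value of its own; the session takes its value from `feed_dict`, keyed either by the node or by its name, as the examples above do. Below is a minimal sketch of that lookup in plain Python. All names are hypothetical, the naming scheme `Placeholder_N` is an assumption, and Python's `float` stands in for `tf.float32`.

```python
import itertools

_ids = itertools.count()  # gives every node a unique default name


class Tensor:
    """A graph node evaluated lazily by a Session (hypothetical, for illustration)."""

    def __init__(self, compute=None, inputs=(), name=None):
        self.compute = compute
        self.inputs = inputs
        self.name = name or "Tensor_{}".format(next(_ids))

    def __add__(self, other):
        return Tensor(lambda x, y: x + y, (self, other))


class Placeholder(Tensor):
    """A node whose value must be supplied through feed_dict."""

    def __init__(self, dtype):
        super().__init__(name="Placeholder_{}".format(next(_ids)))
        self.dtype = dtype


def constant(value):
    return Tensor(lambda: value)


def placeholder(dtype):
    return Placeholder(dtype)


class Session:
    def run(self, tensor, feed_dict=None):
        # Normalize keys so a placeholder can be fed by object or by name.
        feeds = {}
        for key, value in (feed_dict or {}).items():
            feeds[key.name if isinstance(key, Tensor) else key] = value
        return self._evaluate(tensor, feeds)

    def _evaluate(self, tensor, feeds):
        if isinstance(tensor, Placeholder):
            if tensor.name not in feeds:
                raise ValueError("No value fed for " + tensor.name)
            return tensor.dtype(feeds[tensor.name])  # cast, e.g. 10 -> 10.0
        values = [self._evaluate(t, feeds) for t in tensor.inputs]
        return tensor.compute(*values)


sess = Session()
a = placeholder(float)  # Python float stands in for tf.float32
b = constant(32.0)
print(sess.run(a + b, feed_dict={a: 10}))       # 42.0
print(sess.run(a + b, feed_dict={a.name: 10}))  # 42.0
```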

Linear model

Run with TensorFlow.

```python
import tensorflow as tf


def linear_regression():
    epoch_number = 30
    learning_rate = 0.01
    train_features = [1.0, 2.0, 3.0, 4.0, 5.0]
    train_labels = [10.0, 20.0, 30.0, 40.0, 50.0]

    # Model: predict = weights * x + bias, trained with squared-error loss
    weights = tf.Variable(0.0)
    bias = tf.Variable(0.0)
    x = tf.placeholder(tf.float32)
    y = tf.placeholder(tf.float32)
    predict = weights * x + bias
    loss = tf.square(y - predict)

    sgd_optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    train_op = sgd_optimizer.minimize(loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for epoch_index in range(epoch_number):
            # Take one sample from train dataset
            sample_number = len(train_features)
            train_feature = train_features[epoch_index % sample_number]
            train_label = train_labels[epoch_index % sample_number]

            # Update model variables, then print the loss on the fixed sample (1.0, 10.0)
            sess.run(train_op, feed_dict={x: train_feature, y: train_label})
            loss_value = sess.run(loss, feed_dict={x: 1.0, y: 10.0})
            print("Epoch: {}, loss: {}, weight: {}, bias: {}".format(
                epoch_index, loss_value, sess.run(weights), sess.run(bias)))


if __name__ == "__main__":
    linear_regression()
```

Run with MiniFlow.

```python
import miniflow as tf


def linear_regression():
    epoch_number = 30
    learning_rate = 0.01
    train_features = [1.0, 2.0, 3.0, 4.0, 5.0]
    train_labels = [10.0, 20.0, 30.0, 40.0, 50.0]

    # Model: predict = weights * x + bias, trained with squared-error loss
    weights = tf.Variable(0.0)
    bias = tf.Variable(0.0)
    x = tf.placeholder(tf.float32)
    y = tf.placeholder(tf.float32)
    predict = weights * x + bias
    loss = tf.square(y - predict)

    sgd_optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    train_op = sgd_optimizer.minimize(loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for epoch_index in range(epoch_number):
            # Take one sample from train dataset
            sample_number = len(train_features)
            train_feature = train_features[epoch_index % sample_number]
            train_label = train_labels[epoch_index % sample_number]

            # Update model variables, then print the loss on the fixed sample (1.0, 10.0)
            sess.run(train_op, feed_dict={x: train_feature, y: train_label})
            loss_value = sess.run(loss, feed_dict={x: 1.0, y: 10.0})
            print("Epoch: {}, loss: {}, weight: {}, bias: {}".format(
                epoch_index, loss_value, sess.run(weights), sess.run(bias)))


if __name__ == "__main__":
    linear_regression()
```
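For this particular loss, the update that `GradientDescentOptimizer(learning_rate).minimize(loss)` performs can be written out by hand. The sketch below reproduces the same training loop in plain Python with hand-derived gradients; the function name `linear_regression_by_hand` is hypothetical, and it illustrates plain gradient descent rather than either library's internals.

```python
def linear_regression_by_hand(epoch_number=30, learning_rate=0.01):
    train_features = [1.0, 2.0, 3.0, 4.0, 5.0]
    train_labels = [10.0, 20.0, 30.0, 40.0, 50.0]
    weight, bias = 0.0, 0.0
    history = []

    for epoch_index in range(epoch_number):
        # Take one sample, cycling through the training set as in the example above
        sample_number = len(train_features)
        x = train_features[epoch_index % sample_number]
        y = train_labels[epoch_index % sample_number]

        # loss = (y - (weight * x + bias)) ** 2, so by the chain rule:
        #   d loss / d weight = -2 * x * (y - predict)
        #   d loss / d bias   = -2 * (y - predict)
        predict = weight * x + bias
        grad_weight = -2.0 * x * (y - predict)
        grad_bias = -2.0 * (y - predict)

        # Plain gradient descent step
        weight -= learning_rate * grad_weight
        bias -= learning_rate * grad_bias

        # Report the loss on the fixed sample x=1.0, y=10.0, as the example does
        loss_value = (10.0 - (weight * 1.0 + bias)) ** 2
        history.append((epoch_index, loss_value, weight, bias))
        print("Epoch: {}, loss: {}, weight: {}, bias: {}".format(
            epoch_index, loss_value, weight, bias))
    return history


history = linear_regression_by_hand()
```

On the first step (x=1.0, y=10.0, starting from zero) both gradients are -20, so weight and bias each move to 0.2 and the reported loss is (10 - 0.4)^2 = 92.16.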

The gradients computed by MiniFlow and the resulting model variables are accurate.
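A common way to confirm that analytic gradients are accurate is to compare them with central finite differences. Below is a minimal sketch for the squared-error loss of the linear model above. The function names and the relative tolerance `1e-4` are assumptions for illustration, not part of MiniFlow.

```python
def loss(weight, bias, x, y):
    return (y - (weight * x + bias)) ** 2


def analytic_gradients(weight, bias, x, y):
    # Derived by hand with the chain rule
    residual = y - (weight * x + bias)
    return -2.0 * x * residual, -2.0 * residual


def numerical_gradients(weight, bias, x, y, eps=1e-6):
    # Central differences: (f(p + eps) - f(p - eps)) / (2 * eps)
    d_weight = (loss(weight + eps, bias, x, y) - loss(weight - eps, bias, x, y)) / (2 * eps)
    d_bias = (loss(weight, bias + eps, x, y) - loss(weight, bias - eps, x, y)) / (2 * eps)
    return d_weight, d_bias


def gradients_match(weight, bias, x, y, tolerance=1e-4):
    analytic = analytic_gradients(weight, bias, x, y)
    numerical = numerical_gradients(weight, bias, x, y)
    # Relative comparison, falling back to absolute for small gradients
    return all(abs(a - n) <= tolerance * max(1.0, abs(a))
               for a, n in zip(analytic, numerical))


print(gradients_match(0.5, 0.1, 3.0, 30.0))  # True
```

Because the loss is quadratic in each parameter, the central difference is exact apart from floating-point rounding, so any large mismatch points to an error in the analytic derivative.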

Performance

More performance tests are available in the project's benchmark directory.

Contribution

GitHub issues and pull requests are highly appreciated; feel free to contribute.

To release, upload the official miniflow Python package to PyPI:

python setup.py sdist upload
python setup.py sdist --format=gztar
twine upload dist/miniflow-x.x.x.tar.gz

