forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbackend.cs
More file actions
114 lines (102 loc) · 4.69 KB
/
backend.cs
File metadata and controls
114 lines (102 loc) · 4.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
using System;
using System.Collections.Generic;
using System.Text;
using System.Runtime.CompilerServices;
using static Tensorflow.Python;
namespace Tensorflow.Keras
{
/// <summary>
/// Static Keras backend helpers plus the process-wide state they share:
/// the current session, per-graph layer-name counters, tracked variables
/// and the per-graph learning phase.
/// </summary>
public class backend : BackendBase
{
    /* ---------------------------------------- KERAS BACKEND NATIVE OBJECTS ---------------------------------------- */
    public static Func<Array, double> py_sum = sum;
    public static Func<Array, bool> py_all = all;
    //Func<Array, bool> py_any = any;
    //Func<double, double, double, IEnumerable<double>> py_slice = slice;
    public static Session _SESSION = Tensorflow.tf.defaultSession;
    public static Graph _GRAPH = null;

    /// <summary>
    /// Learning phase recorded for each graph.
    /// Initialized eagerly so that <see cref="learning_phase"/> and
    /// <see cref="set_learning_phase"/> can be used before <see cref="clear_session"/>
    /// has ever been called.
    /// </summary>
    public static Dictionary<Graph, GraphLearningPhase> _GRAPH_LEARNING_PHASES = new Dictionary<Graph, GraphLearningPhase>();
    //Dictionary<Graph, Dictionary<string, int>> PER_GRAPH_LAYER_NAME_UIDS;
    public static bool _MANUAL_VAR_INIT = false;
    public static List<string> _LOCAL_DEVICES = null;
    /* -------------------------------------- KERAS BACKEND NATIVE OBJECTS END -------------------------------------- */

    /// <summary>
    /// A global dictionary mapping graph objects to an index of counters used
    /// for various layer names in each graph.
    /// Allows to give unique autogenerated names to layers, in a graph-specific way.
    /// Counters are keyed by a (namespace, prefix) tuple.
    /// </summary>
    public static Dictionary<Graph, Dictionary<(string, string), int>> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<(string, string), int>>();

    /// <summary>Variables registered through <see cref="track_variable"/>, keyed by graph key.</summary>
    public static Dictionary<string, RefVariable> _GRAPH_VARIABLES = new Dictionary<string, RefVariable>();

    public static Dictionary<string, Optimizer> _GRAPH_TF_OPTIMIZERS = new Dictionary<string, Optimizer>();
    public static _DummyEagerGraph _DUMMY_EAGER_GRAPH = new _DummyEagerGraph();

    /// <summary>
    /// Records <paramref name="v"/> in <see cref="_GRAPH_VARIABLES"/> under the
    /// <c>graph_key</c> of the graph it belongs to. An entry already stored under
    /// that key is replaced.
    /// </summary>
    /// <param name="v">The variable to track.</param>
    public static void track_variable(RefVariable v)
    {
        var graph = v.graph;
        _GRAPH_VARIABLES[graph.graph_key] = v;
    }

    /// <summary>
    /// Creates a placeholder tensor through <c>gen_array_ops.placeholder</c>.
    /// </summary>
    /// <param name="shape">Dimensions passed to the <c>TensorShape</c> constructor.</param>
    /// <param name="ndim">Currently not used by this method.</param>
    /// <param name="dtype">Element type forwarded to the placeholder op.</param>
    /// <param name="sparse">Must be <c>false</c>; sparse placeholders are not implemented.</param>
    /// <param name="name">Optional op name forwarded to the placeholder op.</param>
    /// <returns>The tensor returned by <c>gen_array_ops.placeholder</c>.</returns>
    /// <exception cref="NotImplementedException">Thrown when <paramref name="sparse"/> is <c>true</c>.</exception>
    public static Tensor placeholder(int[] shape = null,
        int ndim = -1,
        TF_DataType dtype = TF_DataType.DtInvalid,
        bool sparse = false,
        string name = null)
    {
        if (sparse)
        {
            throw new NotImplementedException("placeholder sparse is true");
        }

        // NOTE(review): `shape` defaults to null and is handed to TensorShape unchanged,
        // and `ndim` is ignored - confirm whether a null shape is supported by TensorShape.
        return gen_array_ops.placeholder(dtype: dtype, shape: new TensorShape(shape), name: name);
    }

    /// <summary>Returns the graph reported by <c>ops.get_default_graph()</c>.</summary>
    public static Graph get_graph()
    {
        return ops.get_default_graph();
    }

    /// <summary>
    /// Returns the next counter value for the (<paramref name="namespace"/>, <paramref name="prefix"/>)
    /// pair in the default graph. Equivalent to calling the tuple overload with
    /// <c>(@namespace, prefix)</c>.
    /// </summary>
    /// <param name="prefix">Name prefix the counter belongs to.</param>
    /// <param name="namespace">Optional namespace that scopes the prefix.</param>
    /// <returns>The incremented counter; 1 on the first call for a given pair.</returns>
    public static int get_uid(string prefix, string @namespace = "")
    {
        return get_uid((@namespace, prefix));
    }

    /// <summary>
    /// Increments and returns the counter stored for <paramref name="name"/> in the
    /// default graph's entry of <see cref="PER_GRAPH_LAYER_NAME_UIDS"/>, creating the
    /// graph entry and the counter on first use.
    /// </summary>
    /// <param name="name">Counter key as a (namespace, prefix) tuple.</param>
    /// <returns>The incremented counter; 1 on the first call for a given key.</returns>
    public static int get_uid((string, string) name)
    {
        var graph = tf.get_default_graph();
        if (!PER_GRAPH_LAYER_NAME_UIDS.TryGetValue(graph, out var counters))
        {
            counters = new defaultdict<(string, string), int>();
            PER_GRAPH_LAYER_NAME_UIDS.Add(graph, counters);
        }

        // The counters are reached through the Dictionary type, whose indexer throws
        // for a missing key, so a missing counter is read as 0 via TryGetValue and the
        // new value is written back explicitly.
        counters.TryGetValue(name, out var count);
        count += 1;
        counters[name] = count;
        return count;
    }

    /// <summary>Discards every per-graph layer-name counter.</summary>
    public static void reset_uids() => PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<(string, string), int>>();

    /// <summary>
    /// Resets the default graph, the layer-name counters, the cached session and the
    /// learning-phase table, then registers the new default graph with learning phase 0.
    /// </summary>
    public static void clear_session()
    {
        ops.reset_default_graph();
        reset_uids();
        _SESSION = null;
        // The returned tensor is not stored; the call is kept because it adds the
        // "keras_learning_phase" op to the new default graph.
        tf.placeholder_with_default(false, new int[] { }, name: "keras_learning_phase");
        _GRAPH_LEARNING_PHASES = new Dictionary<Graph, GraphLearningPhase>();
        _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = 0;
    }

    /// <summary>Sets the <see cref="_MANUAL_VAR_INIT"/> flag.</summary>
    /// <param name="value">New value of the flag.</param>
    public static void manual_variable_initialization(bool value)
    {
        _MANUAL_VAR_INIT = value;
    }

    /// <summary>
    /// Returns the learning phase recorded for the default graph. A graph that has no
    /// entry yet gets the "keras_learning_phase" placeholder added and is registered
    /// with phase 0; an existing entry is returned unchanged.
    /// </summary>
    public static GraphLearningPhase learning_phase()
    {
        var graph = tf.get_default_graph();
        // Only initialize graphs that are NOT registered yet, so a value stored by
        // set_learning_phase is preserved and the placeholder is created once per graph.
        if (!_GRAPH_LEARNING_PHASES.ContainsKey(graph))
        {
            tf.placeholder_with_default(false, shape: new int[] { }, name: "keras_learning_phase");
            _GRAPH_LEARNING_PHASES[graph] = 0;
        }
        return _GRAPH_LEARNING_PHASES[graph];
    }

    /// <summary>
    /// Stores the learning phase for the default graph: 1 when <paramref name="value"/>
    /// is <c>true</c>, 0 otherwise.
    /// </summary>
    /// <param name="value">Whether the phase value 1 (true) or 0 (false) is recorded.</param>
    public static void set_learning_phase(bool value)
    {
        _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = (GraphLearningPhase)(value ? 1 : 0);
    }

    /// <summary>Empty marker type; an instance is held in <see cref="_DUMMY_EAGER_GRAPH"/>.</summary>
    public class _DummyEagerGraph
    { }
}
}