Skip to content

Commit f849af0

Browse files
committed
Change RefVariable to IVariableV1.
1 parent 7d7a317 commit f849af0

File tree

8 files changed

+49
-49
lines changed

8 files changed

+49
-49
lines changed

src/TensorFlowNET.Core/APIs/tf.state.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ namespace Tensorflow
1818
{
1919
public partial class tensorflow
2020
{
21-
public Tensor assign_add<T>(RefVariable @ref, T value,
21+
public Tensor assign_add<T>(IVariableV1 @ref, T value,
2222
bool use_locking = false, string name = null)
2323
=> state_ops.assign_add(@ref, value, use_locking: use_locking, name: name);
2424
}

src/TensorFlowNET.Core/Training/Optimizer.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ public Optimizer(Tensor learning_rate, bool use_locking, string name = null)
106106
/// was not `None`, that operation also increments `global_step`.
107107
/// </returns>
108108
public Operation minimize(Tensor loss,
109-
RefVariable global_step = null,
109+
IVariableV1 global_step = null,
110110
List<ResourceVariable> var_list=null,
111111
GateGradientType gate_gradients = GateGradientType.GATE_OP,
112112
int? aggregation_method=null,
@@ -142,7 +142,7 @@ public Operation minimize(Tensor loss,
142142
/// <returns>
143143
/// An `Operation` that applies the specified gradients. If `global_step`
144144
/// was not None, that operation also increments `global_step`.</returns>
145-
public Operation apply_gradients(Tuple<Tensor, ResourceVariable>[] grads_and_vars, RefVariable global_step = null, string name = null)
145+
public Operation apply_gradients(Tuple<Tensor, ResourceVariable>[] grads_and_vars, IVariableV1 global_step = null, string name = null)
146146
{
147147
// No DistributionStrategy case.
148148
var converted_grads_and_vars = new List<(Tensor, ResourceVariable, _OptimizableVariable)>();
@@ -192,7 +192,7 @@ public Operation apply_gradients(Tuple<Tensor, ResourceVariable>[] grads_and_var
192192
{
193193
tf_with(ops.control_dependencies(new object[] {_finish(update_ops.ToArray(), "update")}), dep =>
194194
{
195-
ops.colocate_with(global_step);
195+
// ops.colocate_with(global_step);
196196
// TODO: port this if branch once ResourceVariable has been ported!
197197
//if (global_step is ResourceVariable)
198198
//{

src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,15 +122,28 @@ Tensor read_value()
122122
return array_ops.identity(value);
123123
});
124124

125+
public Operation assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true)
126+
{
127+
var assign_add_op = gen_resource_variable_ops.assign_add_variable_op(Handle,
128+
ops.convert_to_tensor(delta, dtype: dtype), name: name);
129+
130+
/*if (read_value)
131+
return _lazy_read(assign_add_op);*/
132+
return assign_add_op;
133+
}
134+
125135
public override string ToString()
126-
=> $"tf.Variable '{Name}' shape={shape} dtype={dtype.as_numpy_name()}, numpy={numpy()}";
136+
{
137+
if (tf.context.executing_eagerly())
138+
return $"tf.Variable: '{Name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}, numpy={EagerTensor.GetFormattedString(dtype, numpy())}";
139+
else
140+
return $"tf.Variable: '{Name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}";
141+
}
127142

128143
public NDArray numpy() => read_value().numpy();
129144

130145
protected override void DisposeUnmanagedResources(IntPtr handle)
131146
{
132-
// delete
133-
// c_api.TFE_DeleteResourceVariable(handle);
134147
}
135148
}
136149
}

src/TensorFlowNET.Core/Variables/IVariableV1.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,5 +37,7 @@ public interface IVariableV1
3737
public Operation Op { get; }
3838
public Tensor GraphElement { get; }
3939
public Graph Graph { get; }
40+
public TF_DataType dtype { get; }
41+
public Operation assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true);
4042
}
4143
}

src/TensorFlowNET.Core/Variables/RefVariable.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,5 +401,26 @@ public Tensor initialized_value()
401401
read_value,
402402
initial_value);
403403
}
404+
405+
// Update 'ref' by adding 'value' to it.
406+
// This operation outputs "ref" after the update is done.
407+
// This makes it easier to chain operations that need to use the reset value.
408+
// Args:
409+
// ref: A mutable `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`.
410+
// Should be from a `Variable` node.
411+
// value: A `Tensor`. Must have the same type as `ref`.
412+
// The value to be added to the variable.
413+
// use_locking: An optional `bool`. Defaults to `False`.
414+
// If True, the addition will be protected by a lock;
415+
// otherwise the behavior is undefined, but may exhibit less contention.
416+
// name: A name for the operation(optional).
417+
// Returns:
418+
// A mutable `Tensor`. Has the same type as `ref`.
419+
public Operation assign_add<T>(T value, bool use_locking = false, string name = null, bool read_value = true)
420+
{
421+
var variable = this;
422+
var _op = tf._op_def_lib._apply_op_helper("AssignAdd", name: name, args: new { variable, value, use_locking });
423+
return _op;
424+
}
404425
}
405426
}

src/TensorFlowNET.Core/Variables/ResourceVariable.cs

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -139,17 +139,17 @@ private void _init_from_args(object initial_value = null,
139139
tf_with(ops.name_scope("Assign"), scope1 =>
140140
{
141141
string n = scope1;
142-
initializer_op = gen_resource_variable_ops.assign_variable_op(handle,
143-
variables._try_guard_against_uninitialized_dependencies(name, _initial_value),
144-
name: n);
142+
var _initial_value2 = variables._try_guard_against_uninitialized_dependencies(name, _initial_value);
143+
initializer_op = gen_resource_variable_ops.assign_variable_op(handle, _initial_value2, name: n);
145144
});
146145
}
147146

148147
// Manually assign reads to the handle's device to avoid log
149148
// messages.
150149
tf_with(ops.name_scope("Read"), delegate
151150
{
152-
var value = _read_variable_op();
151+
var value = gen_resource_variable_ops.read_variable_op(handle, _dtype);
152+
// _maybe_set_handle_data(dtype, handle, value);
153153
_graph_element = value;
154154
});
155155

@@ -233,16 +233,5 @@ public Tensor sparse_read(Tensor indices, string name = "Gather")
233233
return array_ops.identity(value);
234234
});
235235
}
236-
237-
public override string ToString()
238-
{
239-
return $"tf.Variable: '{Name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}, numpy={EagerTensor.GetFormattedString(dtype, numpy())}";
240-
}
241-
242-
protected override void DisposeUnmanagedResources(IntPtr handle)
243-
{
244-
// delete
245-
// c_api.TFE_DeleteResourceVariable(handle);
246-
}
247236
}
248237
}

src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -120,27 +120,6 @@ public static Tensor assign_sub(RefVariable @ref,
120120
return _op.outputs[0];
121121
}
122122

123-
124-
// Update 'ref' by adding 'value' to it.
125-
// This operation outputs "ref" after the update is done.
126-
// This makes it easier to chain operations that need to use the reset value.
127-
// Args:
128-
// ref: A mutable `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`.
129-
// Should be from a `Variable` node.
130-
// value: A `Tensor`. Must have the same type as `ref`.
131-
// The value to be added to the variable.
132-
// use_locking: An optional `bool`. Defaults to `False`.
133-
// If True, the addition will be protected by a lock;
134-
// otherwise the behavior is undefined, but may exhibit less contention.
135-
// name: A name for the operation(optional).
136-
// Returns:
137-
// A mutable `Tensor`. Has the same type as `ref`.
138-
public static Tensor assign_add<T>(RefVariable @ref, T value, bool use_locking = false, string name = null)
139-
{
140-
var _op = tf._op_def_lib._apply_op_helper("AssignAdd", name: name, args: new { @ref, value, use_locking });
141-
return _op.outputs[0];
142-
}
143-
144123
/// <summary>
145124
/// Adds sparse updates to a variable reference.
146125
/// </summary>

src/TensorFlowNET.Core/Variables/state_ops.cs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,15 +106,11 @@ public static Tensor assign_sub(RefVariable @ref,
106106
// Returns:
107107
// Same as "ref". Returned as a convenience for operations that want
108108
// to use the new value after the variable has been updated.
109-
public static Tensor assign_add<T>(RefVariable @ref,
109+
public static Operation assign_add<T>(IVariableV1 @ref,
110110
T value,
111111
bool use_locking = false,
112112
string name = null)
113-
{
114-
if (@ref.dtype.is_ref_dtype())
115-
return gen_state_ops.assign_add(@ref, value, use_locking: use_locking, name: name);
116-
throw new NotImplementedException("assign_add");
117-
}
113+
=> @ref.assign_add(value, use_locking: use_locking, name: name);
118114

119115
public static Tensor scatter_add(RefVariable @ref, Tensor indices, Tensor updates, bool use_locking = false, string name = null)
120116
{

0 commit comments

Comments
 (0)