cryptic
10/28/2022, 10:25 AM
make lint
? I'm sorry but I'm not quite sure about it (https://github.com/flyteorg/flytekit/actions/runs/3326738618/jobs/5511140308#step:6:360)
Samhita Alla
pre-commit install
in your flytekit repo locally? That should automatically warn you about lint errors during the commit.
cryptic
10/28/2022, 10:36 AM
Samhita Alla
cryptic
10/28/2022, 12:29 PM
def test_to_python_value_and_literal(transformer, python_type, format, python_val):
ctx = context_manager.FlyteContext.current_context()
tf = transformer
lt = tf.get_literal_type(python_type)
lv = tf.to_literal(ctx, python_val, type(python_val), lt) # type: ignore
> assert lv.scalar.blob.metadata == BlobMetadata(
type=BlobType(
format=format,
dimensionality=BlobType.BlobDimensionality.SINGLE,
)
)
E AttributeError: 'NoneType' object has no attribute 'blob'
Samhita Alla
cryptic
10/28/2022, 12:53 PM
test_to_python_value_and_literal
?
Samhita Alla
cryptic
10/28/2022, 12:54 PM
a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
b = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
result_tensor = tf.matmul(a, b)
@pytest.mark.parametrize(
    "transformer,python_type,format,python_val",
    [
        (
            TensorFlowTensorTransformer(),
            tf.Tensor,
            TensorFlowTensorTransformer.TENSORFLOW_FORMAT,
            result_tensor,
        )
    ],
)
def test_to_python_value_and_literal(transformer, python_type, format, python_val):
    """Round-trip a value through the transformer: Python value -> literal -> Python value.

    Checks that the produced literal exposes a single-part blob with the
    expected format and a URI, and that converting the literal back yields
    a value whose ``numpy()`` contents equal the input's.
    """
    flyte_ctx = context_manager.FlyteContext.current_context()
    literal_type = transformer.get_literal_type(python_type)
    literal = transformer.to_literal(flyte_ctx, python_val, type(python_val), literal_type)  # type: ignore

    # The literal's blob metadata must describe a SINGLE blob in the given format.
    expected_blob_type = BlobType(
        format=format,
        dimensionality=BlobType.BlobDimensionality.SINGLE,
    )
    assert literal.scalar.blob.metadata == BlobMetadata(type=expected_blob_type)
    assert literal.scalar.blob.uri is not None

    # Converting back must reproduce the original element values.
    restored = transformer.to_python_value(flyte_ctx, literal, python_val, python_type)
    assert (restored.numpy() == python_val.numpy()).all()
Samhita Alla
to_literal
?
cryptic
10/28/2022, 12:57 PMdef to_literal(
self,
ctx: FlyteContext,
python_val: tf.Tensor,
python_type: Type[tf.Tensor],
expected: LiteralType,
) -> Literal:
# Serialize `python_val` to a local file, upload it through `ctx.file_access`,
# and return a Literal that references the uploaded file together with the
# tensor's dtype name.
# `python_type` and `expected` are accepted but not read in this body.
#
# Blob metadata: a SINGLE-dimensionality blob tagged with this transformer's
# TENSORFLOW_FORMAT.
meta = BlobMetadata(
type=_core_types.BlobType(
format=self.TENSORFLOW_FORMAT,
dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
)
)
# Obtain a fresh local path from the file-access layer.
local_path = ctx.file_access.get_random_local_path()
#pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)
# Save the `tf.tensor` as a file on disk
#global local_file_path
# The random local path is used as a directory; the file name is fixed.
local_path = os.path.join(local_path, "tensor_data")
# NOTE(review): the mkdir above is commented out; presumably
# tf.io.write_file creates missing parent directories itself — confirm.
tf.io.write_file(local_path, tf.io.serialize_tensor(python_val))
# Keep the dtype name so it can travel with the blob; presumably needed to
# parse the serialized tensor back — verify against to_python_value.
tensor_dtype = python_val.dtype.name
# Upload the local file to a newly generated remote path as a single
# (non-multipart) object.
remote_path = ctx.file_access.get_random_remote_path(local_path)
ctx.file_access.put_data(local_path, remote_path, is_multipart=False)
# NOTE(review): the returned Literal populates `collection` (a blob literal
# followed by a string primitive holding the dtype name), not `scalar`.
# Code that reads `.scalar.blob` on this return value would not find a blob
# there — this looks consistent with an
# "'NoneType' object has no attribute 'blob'" failure; confirm which side
# (this return shape or the caller's expectation) is intended.
return Literal(
collection=LiteralCollection(
literals=[
Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))),
Literal(scalar=Scalar(primitive=Primitive(string_value=tensor_dtype))),
]
)
)
Samhita Alla
cryptic
10/28/2022, 12:59 PM