Issue | 138739
Summary | Tf.MaxPool3D not supported
Labels |
Assignees |
Reporter | GiuseppeSorrentino99
Hello, I am trying to convert a TF network to TOSA, but it seems one of the layers is not supported:
```
output/tosa.mlir:46:25: error: operation being parsed with an unregistered dialect. If this is intended, please use -allow-unregistered-dialect with the MLIR tool used
%43 = "tf.MaxPool3D"(%42) {data_format = "NDHWC", device = "", ksize = [1, 2, 2, 2, 1], padding = "VALID", strides = [1, 2, 2, 2, 1]} : (tensor<1x128x128x128x16xf32>) -> tensor<1x64x64x64x16xf32>
```
Since Conv3D is supported, is there a workaround for this problem?
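One possible workaround (a sketch, untested against this toolchain): a 2×2×2 VALID max pool with stride 2 can be decomposed into ops that do have TOSA legalizations — reshape, a 2D max pool over H and W, and a max-reduction over paired depth slices. The helper name `maxpool3d_2x2x2` is illustrative, not part of any API:

```python
import tensorflow as tf

def maxpool3d_2x2x2(x):
    """2x2x2 VALID max pool with stride 2, built from TOSA-friendly ops.

    Assumes x is [B, D, H, W, C] with static, even D, H, W.
    """
    D, H, W, C = x.shape[1], x.shape[2], x.shape[3], x.shape[4]
    # Fold depth into batch, then pool H and W with an ordinary 2D max pool
    x = tf.reshape(x, [-1, H, W, C])
    x = tf.nn.max_pool2d(x, ksize=2, strides=2, padding='VALID')
    # Un-fold, group consecutive depth slices in pairs, and max over each pair
    x = tf.reshape(x, [-1, D // 2, 2, H // 2, W // 2, C])
    return tf.reduce_max(x, axis=2)
```

In the model below this would replace `layers.MaxPool3D(2)` with `layers.Lambda(maxpool3d_2x2x2)`; whether the resulting reshape/reduce_max pattern survives `--tf-to-tosa-pipeline` cleanly would still need to be verified.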
I also attach the code for the network and the set of commands used to reproduce the error:
**NN**
```
import tensorflow as tf
from tensorflow.keras import layers, models
import tensorflow_addons as tfa


class SpatialTransformer(layers.Layer):
    """3D Spatial Transformer using batched 2D warps and static shape enforcement."""

    def call(self, inputs):
        vol, flow = inputs  # vol: [B,D,H,W,C], flow: [B,D,H,W,3]
        # 1. Enforce static (non-zero) shapes to satisfy TOSA requirements
        #    (the TOSA dialect expects all dims ≥ 1 and statically known)
        vol = tf.ensure_shape(vol, [None, vol.shape[1], vol.shape[2], vol.shape[3], vol.shape[4]])
        flow = tf.ensure_shape(flow, [None, flow.shape[1], flow.shape[2], flow.shape[3], 3])
        # 2. Flatten depth dimension into batch: [B,D,H,W,C] → [B*D,H,W,C]
        shape = tf.shape(vol)
        B, D, H, W, C = shape[0], shape[1], shape[2], shape[3], vol.shape[4]
        vol_flat = tf.reshape(vol, tf.stack([B * D, H, W, C]))
        flow_flat = tf.reshape(flow, tf.stack([B * D, H, W, 3]))  # flow has 3 channels, not C
        # 3. Perform a single batched 2D warp via dense_image_warp,
        #    avoiding tf.map_fn loops entirely
        moved_flat = tfa.image.dense_image_warp(vol_flat, flow_flat[..., :2])
        # 4. Restore original shape: [B*D,H,W,C] → [B,D,H,W,C]
        moved = tf.reshape(moved_flat, tf.stack([B, D, H, W, C]))
        return moved


def conv_block(x, filters, convs=2, kernel_size=3, activation='relu'):
    for _ in range(convs):
        x = layers.Conv3D(filters, kernel_size, padding='same',
                          kernel_initializer='he_normal')(x)
        x = layers.Activation(activation)(x)
    return x


def build_minimal_voxelmorph(inshape,
                             enc_features=(16, 32, 32, 32),
                             dec_features=(32, 32, 32, 32, 32, 16, 16)):
    moving = layers.Input(shape=(*inshape, 1), name='moving')
    fixed = layers.Input(shape=(*inshape, 1), name='fixed')
    x = layers.Concatenate(axis=-1)([moving, fixed])
    skips = []
    for f in enc_features:
        x = conv_block(x, f)
        skips.append(x)
        x = layers.MaxPool3D(2)(x)
    x = conv_block(x, enc_features[-1] * 2)
    for f, skip in zip(dec_features, reversed(skips)):
        x = layers.UpSampling3D(2)(x)
        x = layers.Concatenate(axis=-1)([x, skip])
        x = conv_block(x, f)
    flow = layers.Conv3D(3, 3, padding='same', name='flow')(x)
    moved = SpatialTransformer(name='moved')([moving, flow])
    return models.Model(inputs=[moving, fixed],
                        outputs=[moved, flow],
                        name='VoxelmorphMinimalFlatten')


# Instantiate model for a 128³ volume
model = build_minimal_voxelmorph((128, 128, 128))
model.summary()
```
which is translated to TOSA through the following commands:
```
docker run -u $(id -u):$(id -g) -v $(pwd):/working_dir --rm agostini01/soda \
tf-mlir-translate \
--graphdef-to-mlir \
--tf-input-arrays=fixed,moving \
--tf-input-data-types=DT_FLOAT,DT_FLOAT \
--tf-input-shapes=1,128,128,128,1:1,128,128,128,1 \
--tf-output-arrays=Identity,Identity_1 \
$1 \
-o output/tf.mlir
docker run -u $(id -u):$(id -g) -v $(pwd):/working_dir --rm agostini01/soda \
tf-opt \
--tf-executor-to-functional-conversion \
--tf-region-control-flow-to-functional \
--tf-shape-inference \
--tf-to-tosa-pipeline \
output/tf.mlir \
-o $2
```
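As a quick sanity check on the pipeline output, the emitted MLIR can be grepped for any remaining `tf.*` operations (the file name matches the `tf-opt` output above):

```shell
# List tf.* ops left after the TOSA pipeline; an empty list means full legalization
grep -o '"tf\.[A-Za-z0-9_]*"' output/tosa.mlir | sort | uniq -c
```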
In practice, while most of the operations are translated, tf.MaxPool3D is not; it also does not appear in the list of supported operators. I am therefore wondering if there is a solution for this.
Thanks in advance for any support.
_______________________________________________
llvm-bugs mailing list
llvm-bugs@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-bugs