25 using namespace tensorflow;
27 const Tensor& inp_tensor =
context->input(0);
30 inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
31 errors::InvalidArgument(
"RoiPool expects "
32 "(batch_size,num_points,3) inp shape"));
33 int batch_size = inp_tensor.shape().dim_size(0);
34 int pts_num = inp_tensor.shape().dim_size(1);
35 auto inp_flat = inp_tensor.flat<
float>();
36 const float* inp = &(inp_flat(0));
38 const Tensor& boxes3d_tensor =
context->input(1);
40 boxes3d_tensor.dims() == 3 &&
41 boxes3d_tensor.shape().dim_size(2) == 7,
42 errors::InvalidArgument(
44 "(batch_size,num_boxes,7) boxes3d shape"));
45 int boxes_num = boxes3d_tensor.shape().dim_size(1);
46 auto boxes3d_flat = boxes3d_tensor.flat<
float>();
47 const float* boxes3d = &(boxes3d_flat(0));
49 const Tensor& feats_tensor =
context->input(2);
51 feats_tensor.dims() == 3 &&
52 feats_tensor.shape().dim_size(1) == pts_num,
53 errors::InvalidArgument(
55 "(batch_size,num_points,feats) feats shape"));
56 int feature_in_len = feats_tensor.shape().dim_size(2);
57 auto feats_flat = feats_tensor.flat<
float>();
58 const float* feats = &(feats_flat(0));
64 TensorShape{batch_size, boxes_num,
65 sampled_pts_num, 3 + feature_in_len},
67 auto out_flat0 = out_feats->flat<
float>();
68 float* out0 = &(out_flat0(0));
72 1, TensorShape{batch_size, boxes_num},
74 auto out_flat1 = out_flags->flat<
int>();
75 int* out1 = &(out_flat1(0));
77 Kernel(
context, batch_size, pts_num, boxes_num, feature_in_len,