ImageAugmenter only for RGB and uchar
crohkohl opened this issue · comments
Christopher commented
Hi,
during my testing I tried to use the image augmenter with differently shaped data, e.g. 2x128x128 or 10x128x128.
This is currently not supported and it is always assumed that the data is uchar.
Here is my proposed enhancement for image_augmenter-inl.hpp:
// Applies one randomly drawn affine transform (shear, rotation, scale and
// aspect-ratio change) to every matrix in srcMats, then crops every warped
// matrix to a shape_[1] x shape_[2] (width x height) window.
//
// The random parameters are drawn once and the output size is derived from
// srcMats[0], so all matrices receive the same warp and the same crop.
//
// srcMats: input matrices; the code sizes everything from the first entry,
//          so the entries are presumably expected to share one size.
// prnd:    random source used for the warp parameters and the crop offset.
// Returns the warped and cropped matrices (views into the warped images),
// or an empty vector when srcMats is empty.
std::vector<cv::Mat> Process(const std::vector<cv::Mat> &srcMats, utils::RandomSampler *prnd)
{
  std::vector<cv::Mat> resMats;
  // Nothing to transform; this also protects the srcMats[0] / resMats[0]
  // accesses below from reading past the end of an empty vector.
  if (srcMats.empty()) return resMats;
  resMats.reserve(srcMats.size());

  // shear: symmetric around zero, bounded by max_shear_ratio_
  // (assuming NextDouble() yields values in [0, 1) -- TODO confirm)
  float s = prnd->NextDouble() * max_shear_ratio_ * 2 - max_shear_ratio_;
  // rotate: random angle first, then overridden by a fixed rotate_ or by a
  // pick from rotate_list_ when those are configured
  int angle = prnd->NextUInt32(max_rotate_angle_ * 2) - max_rotate_angle_;
  if (rotate_ > 0) angle = rotate_;
  if (rotate_list_.size() > 0) {
    // NOTE(review): the bound is size() - 1; if NextUInt32(n) excludes n the
    // last list entry can never be chosen -- confirm against RandomSampler.
    angle = rotate_list_[prnd->NextUInt32(rotate_list_.size() - 1)];
  }
  float a = cos(angle / 180.0 * M_PI);
  float b = sin(angle / 180.0 * M_PI);
  // scale
  float scale = prnd->NextDouble() * (max_random_scale_ - min_random_scale_) + min_random_scale_;
  // aspect ratio: centred on 1, varied by +/- max_aspect_ratio_
  float ratio = prnd->NextDouble() * max_aspect_ratio_ * 2 - max_aspect_ratio_ + 1;
  // split the scale into per-axis factors so that (hs + ws) / 2 == scale
  float hs = 2 * scale / (1 + ratio);
  float ws = ratio * hs;
  // new width and height, clamped to [min_img_size_, max_img_size_]
  float new_width = std::max(min_img_size_, std::min(max_img_size_, scale * srcMats[0].cols));
  float new_height = std::max(min_img_size_, std::min(max_img_size_, scale * srcMats[0].rows));

  // 2x3 affine matrix: left 2x2 part combines shear, rotation and scaling
  cv::Mat M(2, 3, CV_32F);
  M.at<float>(0, 0) = hs * a - s * b * ws;
  M.at<float>(1, 0) = -b * ws;
  M.at<float>(0, 1) = hs * b + s * a * ws;
  M.at<float>(1, 1) = a * ws;
  // translation column: moves the transformed source centre onto the centre
  // of the new_width x new_height output
  float ori_center_width = M.at<float>(0, 0) * srcMats[0].cols + M.at<float>(0, 1) * srcMats[0].rows;
  float ori_center_height = M.at<float>(1, 0) * srcMats[0].cols + M.at<float>(1, 1) * srcMats[0].rows;
  M.at<float>(0, 2) = (new_width - ori_center_width) / 2;
  M.at<float>(1, 2) = (new_height - ori_center_height) / 2;

  // cv::Size takes integers; the float sizes are truncated here
  const cv::Size out_size(static_cast<int>(new_width), static_cast<int>(new_height));
  for (size_t iz = 0; iz < srcMats.size(); ++iz) {
    cv::Mat tmp;
    // pixels that fall outside the source are filled with zeros
    cv::warpAffine(srcMats[iz], tmp, M, out_size, cv::INTER_LINEAR, cv::BORDER_CONSTANT, cv::Scalar(0.0f, 0.0f, 0.0f));
    resMats.push_back(tmp);
  }

  // slack between the warped image and the crop window
  // NOTE(review): if index_t is unsigned and the warped image is smaller than
  // the crop, these subtractions wrap around -- confirm the size limits
  // guarantee new_width >= shape_[1] and new_height >= shape_[2].
  mshadow::index_t y = resMats[0].rows - shape_[2];
  mshadow::index_t x = resMats[0].cols - shape_[1];
  if (rand_crop_ != 0) {
    // random crop position inside the slack
    y = prnd->NextUInt32(y + 1);
    x = prnd->NextUInt32(x + 1);
  } else {
    // centre crop
    y /= 2; x /= 2;
  }
  cv::Rect roi(x, y, shape_[1], shape_[2]);
  for (size_t iz = 0; iz < resMats.size(); ++iz) {
    resMats[iz] = resMats[iz](roi);
  }
  return resMats;
}
// Augments a 3-D tensor by treating every slice data[k] as a separate
// single-channel float image, so the number of slices and the element type
// are not tied to 3-channel uchar images.
//
// data: input tensor, indexed as data[k][i][j] (slice, row, column).
// prnd: random source forwarded to the cv::Mat overload of Process.
// Returns data unchanged when NeedProcess() is false or data has no slices;
// otherwise returns the member buffer tmpres, resized to the number of
// slices and the size of the processed matrices and filled with the result.
virtual mshadow::Tensor<cpu, 3> Process(mshadow::Tensor<cpu, 3> data,
                                        utils::RandomSampler *prnd) {
  if (!NeedProcess()) return data;
  // Without slices there is nothing to augment; returning here also keeps
  // the code below from indexing resMats[0] on an empty vector.
  if (data.size(0) == 0) return data;

  // copy each slice into its own single-channel float matrix
  std::vector<cv::Mat> resMats;
  resMats.reserve(data.size(0));
  for (index_t k = 0; k < data.size(0); ++k) {
    cv::Mat res(data.size(1), data.size(2), CV_32FC1);
    for (index_t i = 0; i < data.size(1); ++i) {
      for (index_t j = 0; j < data.size(2); ++j) {
        res.at<float>(i, j) = data[k][i][j];
      }
    }
    resMats.push_back(res);
  }

  // the same warp and crop is applied to all slices
  resMats = Process(resMats, prnd);

  // copy the processed matrices back into the reusable output buffer
  tmpres.Resize(mshadow::Shape3(resMats.size(), resMats[0].rows, resMats[0].cols));
  for (index_t k = 0; k < resMats.size(); ++k) {
    for (index_t i = 0; i < tmpres.size(1); ++i) {
      for (index_t j = 0; j < tmpres.size(2); ++j) {
        tmpres[k][i][j] = resMats[k].at<float>(i, j);
      }
    }
  }
  return tmpres;
}