Skip to content

Commit

Permalink
fix pse_ctw1500 training error (#806)
Browse files Browse the repository at this point in the history
* fix pse_ctw1500 training error

* fix pse_ctw1500 training error

* fix pse_ctw1500 training error

* ci fix
  • Loading branch information
alien-0119 authored Feb 14, 2025
1 parent 8e81c14 commit fecd8ed
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 19 deletions.
13 changes: 10 additions & 3 deletions mindocr/data/transforms/det_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import pyclipper
from shapely.geometry import Polygon, box

from ...utils.misc import is_uneven_nested_list

__all__ = [
"DetLabelEncode",
"BorderMap",
Expand Down Expand Up @@ -595,9 +597,14 @@ def _shrink(self, text_polys, rate, max_shr=20):
if not shrinked_bbox:
shrinked_text_polys.append(bbox)
continue

shrinked_bbox = np.array(shrinked_bbox)[0]
shrinked_bbox = np.array(shrinked_bbox)
if is_uneven_nested_list(shrinked_bbox):
shrinked_bbox = np.array(shrinked_bbox, dtype=object)[0]
else:
shrinked_bbox = np.array(shrinked_bbox)[0]
if is_uneven_nested_list(shrinked_bbox):
shrinked_bbox = np.array(shrinked_bbox, dtype=object)
else:
shrinked_bbox = np.array(shrinked_bbox)
if shrinked_bbox.shape[0] <= 2:
shrinked_text_polys.append(bbox)
continue
Expand Down
8 changes: 5 additions & 3 deletions mindocr/losses/det_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def __init__(self, alpha=0.7, ohem_ratio=3):
self.zeros_like = ops.ZerosLike()
self.add = ops.Add()
self.gather = ops.Gather()
self.upsample = nn.ResizeBilinear()
self.upsample = ops.interpolate

def ohem_batch(self, scores, gt_texts, training_masks):
"""
Expand Down Expand Up @@ -334,8 +334,10 @@ def construct(self, model_predict, gt_texts, gt_kernels, training_masks):
Tensor: The computed loss value.
"""
batch_size = model_predict.shape[0]
model_predict = self.upsample(model_predict, scale_factor=4)
h, w = model_predict.shape[2:]
scale_factor = 4
origin_h, origin_w = model_predict.shape[2:]
h, w = origin_h * scale_factor, origin_w * scale_factor
model_predict = self.upsample(model_predict, size=(h, w), mode="bilinear")
texts = self.slice(model_predict, (0, 0, 0, 0), (batch_size, 1, h, w))
texts = self.reshape(texts, (batch_size, h, w))
selected_masks_text = self.ohem_batch(texts, gt_texts, training_masks)
Expand Down
15 changes: 2 additions & 13 deletions mindocr/postprocess/det_db_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from mindspore import Tensor

from ..data.transforms.det_transforms import expand_poly
from ..utils.misc import is_uneven_nested_list
from .det_base_postprocess import DetBasePostprocess

__all__ = ["DBPostprocess"]
Expand Down Expand Up @@ -111,7 +112,7 @@ def _extract_preds(self, pred: np.ndarray, bitmap: np.ndarray):

poly = Polygon(points)
poly_list = expand_poly(points, distance=poly.area * self._expand_ratio / poly.length)
if self._is_uneven_nested_list(poly_list):
if is_uneven_nested_list(poly_list):
poly = np.array(poly_list, dtype=object)
else:
poly = np.array(poly_list)
Expand All @@ -138,18 +139,6 @@ def _extract_preds(self, pred: np.ndarray, bitmap: np.ndarray):
return polys, scores
return np.array(polys), np.array(scores).astype(np.float32)

def _is_uneven_nested_list(self, arr_list):
if not isinstance(arr_list, list):
return False

first_length = len(arr_list[0]) if isinstance(arr_list[0], list) else None

for sublist in arr_list:
if not isinstance(sublist, list) or len(sublist) != first_length:
return True

return False

@staticmethod
def _fit_box(contour):
"""
Expand Down
13 changes: 13 additions & 0 deletions mindocr/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,16 @@ def is_ms_version_2():
make compatibilities in differenct Mindspore version
"""
return version.parse(ms.__version__) >= version.parse("2.0.0rc")


def is_uneven_nested_list(arr_list):
if not isinstance(arr_list, list):
return False

first_length = len(arr_list[0]) if isinstance(arr_list[0], list) else None

for sublist in arr_list:
if not isinstance(sublist, list) or len(sublist) != first_length:
return True

return False

0 comments on commit fecd8ed

Please sign in to comment.