From 27539f081bd6da80896a8787d9deb967ea484099 Mon Sep 17 00:00:00 2001 From: Peter Boin Date: Fri, 4 Jan 2019 00:29:48 +1100 Subject: [PATCH] Matrix validation robuustness + tests --- cadquery/occ_impl/geom.py | 37 +++++++++++++++++++------------------ tests/TestCadObjects.py | 21 ++++++++++++++++++--- 2 files changed, 37 insertions(+), 21 deletions(-) diff --git a/cadquery/occ_impl/geom.py b/cadquery/occ_impl/geom.py index 9592f103..470daae7 100644 --- a/cadquery/occ_impl/geom.py +++ b/cadquery/occ_impl/geom.py @@ -215,20 +215,21 @@ class Matrix: self.wrapped = gp_Trsf() elif isinstance(matrix, gp_Trsf): self.wrapped = matrix - elif isinstance(matrix, list): + elif isinstance(matrix, (list, tuple)): + # Validate matrix size & 4x4 last row value + valid_sizes = all( + (isinstance(row, (list, tuple)) and (len(row) == 4)) + for row in matrix + ) and len(matrix) in (3, 4) + if not valid_sizes: + raise TypeError("Matrix constructor requires 2d list of 4x3 or 4x4, but got: {!r}".format(matrix)) + elif (len(matrix) == 4) and (tuple(matrix[3]) != (0,0,0,1)): + raise ValueError("Expected the last row to be [0,0,0,1], but got: {!r}".format(matrix[3])) + + # Assign values to matrix self.wrapped = gp_Trsf() - if len(matrix) == 3: - flattened = [e for row in matrix for e in row] - self.wrapped.SetValues(*flattened) - elif len(matrix) == 4: - # Only use first 3 rows - the last must be [0, 0, 0, 1]. - lastRow = matrix[3] - if lastRow != [0., 0., 0., 1.]: - raise ValueError("Expected the last row to be [0,0,0,1], but got: {}".format(lastRow)) - flattened = [e for row in matrix[0:3] for e in row] - self.wrapped.SetValues(*flattened) - else: - raise TypeError("Matrix constructor requires list of length 12 or 16") + flattened = [e for row in matrix[:3] for e in row] + self.wrapped.SetValues(*flattened) else: raise TypeError( "Invalid param to matrix constructor: {}".format(matrix)) @@ -282,18 +283,18 @@ class Matrix: and column parameters start at zero, which is consistent with most python libraries, but is counter to gp_Trsf(), which is 1-indexed. """ - if len(rc) != 2: + if not isinstance(rc, tuple) or (len(rc) != 2): raise IndexError("Matrix subscript must provide (row, column)") - r, c = rc[0], rc[1] - if r >= 0 and r < 4 and c >= 0 and c < 4: + (r, c) = rc + if (0 <= r <= 3) and (0 <= c <= 3): if r < 3: - return self.wrapped.Value(r+1,c+1) + return self.wrapped.Value(r + 1, c + 1) else: # gp_Trsf doesn't provide access to the 4th row because it has # an implied value as below: return [0., 0., 0., 1.][c] else: - raise IndexError("Out of bounds access into 4x4 matrix: {}".format(rc)) + raise IndexError("Out of bounds access into 4x4 matrix: {!r}".format(rc)) class Plane(object): diff --git a/tests/TestCadObjects.py b/tests/TestCadObjects.py index aa5c4ad4..3c112790 100644 --- a/tests/TestCadObjects.py +++ b/tests/TestCadObjects.py @@ -166,14 +166,19 @@ class TestCadObjects(BaseTest): [0., 1., 0., 2.], [0., 0., 1., 3.], [0., 0., 0., 1.]] + vals4x4_tuple = tuple(tuple(r) for r in vals4x4) # test constructor with 16-value input m = Matrix(vals4x4) self.assertEqual(vals4x4, matrix_vals(m)) + m = Matrix(vals4x4_tuple) + self.assertEqual(vals4x4, matrix_vals(m)) # test constructor with 12-value input (the last 4 are an implied # [0,0,0,1]) - m = Matrix(vals4x4[0:12]) + m = Matrix(vals4x4[:3]) + self.assertEqual(vals4x4, matrix_vals(m)) + m = Matrix(vals4x4_tuple[:3]) self.assertEqual(vals4x4, matrix_vals(m)) # Test 16-value input with invalid values for the last 4 @@ -184,14 +189,24 @@ class TestCadObjects(BaseTest): with self.assertRaises(ValueError): Matrix(invalid) - # Test input with invalid size + # Test input with invalid size / nested types + with self.assertRaises(TypeError): + Matrix([[1, 2, 3, 4], [1, 2, 3], [1, 2, 3, 4]]) with self.assertRaises(TypeError): Matrix([1,2,3]) + # Invalid sub-type + with self.assertRaises(TypeError): + Matrix([[1, 2, 3, 4], 'abc', [1, 2, 3, 4]]) + # test out-of-bounds access m = Matrix() with self.assertRaises(IndexError): - m[5, 5] + m[0, 4] + with self.assertRaises(IndexError): + m[4, 0] + with self.assertRaises(IndexError): + m['ab'] def testTranslate(self):