NumPy intersect1d with arrays whose elements are matrices


I have two arrays, one of shape (200000, 28, 28) and the other of shape (10000, 28, 28), so effectively two arrays whose elements are 28x28 matrices. I want to count and extract all the elements (as an array of shape (N, 28, 28)) that occur in both arrays. Plain for loops are far too slow, so I tried NumPy's intersect1d, but I don't know how to apply it to arrays like these.
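For illustration, calling `np.intersect1d` directly on such arrays does not compare whole matrices: it flattens both inputs and returns the common scalar values. A minimal sketch with hypothetical 2x2 data standing in for the real arrays:

```python
import numpy as np

# Hypothetical small stand-ins for the (200000, 28, 28) and (10000, 28, 28) arrays
a = np.array([[[1, 2], [3, 4]],
              [[5, 6], [7, 8]]])
b = np.array([[[1, 2], [3, 9]]])  # shares the scalars 1, 2, 3 with a, but no whole matrix

# intersect1d ravels its inputs, so it intersects individual scalar entries
common_values = np.intersect1d(a, b)
print(common_values)  # [1 2 3]
```

The result is a 1-d array of values, so the inputs have to be turned into something with one comparable item per matrix before `intersect1d` can help.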

Tags: python, arrays, numpy | asked 2017-01-01 16:01 | 1 answer

Answers (1)

  1. 2017-01-01 18:01

    Using the approach from this question about unique rows (viewing each row as a single void item):

    import numpy as np

    def intersect_along_first_axis(a, b):
        # check that casting to void will create equal size elements
        assert a.shape[1:] == b.shape[1:]
        assert a.dtype == b.dtype
        # one void item spans one whole element; orig_dt restores the element shape
        void_dt = np.dtype((np.void, a.dtype.itemsize * int(np.prod(a.shape[1:]))))
        orig_dt = np.dtype((a.dtype, a.shape[1:]))
        # convert to 1d void arrays
        a = np.ascontiguousarray(a)
        b = np.ascontiguousarray(b)
        a_void = a.reshape(a.shape[0], -1).view(void_dt)
        b_void = b.reshape(b.shape[0], -1).view(void_dt)
        # intersect, then convert back
        return np.intersect1d(b_void, a_void).view(orig_dt)
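    A self-contained sketch of the function with a small usage example, assuming hypothetical random `uint8` data (an integer dtype, so the byte-wise comparison is exact). The count asked for in the question is simply the length of the result's first axis:

```python
import numpy as np

def intersect_along_first_axis(a, b):
    # Elements must share shape and dtype so each becomes one equal-size void item
    assert a.shape[1:] == b.shape[1:]
    assert a.dtype == b.dtype
    void_dt = np.dtype((np.void, a.dtype.itemsize * int(np.prod(a.shape[1:]))))
    orig_dt = np.dtype((a.dtype, a.shape[1:]))
    a = np.ascontiguousarray(a)
    b = np.ascontiguousarray(b)
    a_void = a.reshape(a.shape[0], -1).view(void_dt)
    b_void = b.reshape(b.shape[0], -1).view(void_dt)
    # intersect1d sorts and deduplicates; viewing as orig_dt restores (K, 28, 28)
    return np.intersect1d(b_void, a_void).view(orig_dt)

# Hypothetical data: b contains three matrices copied from a plus random ones
rng = np.random.default_rng(0)
a = rng.integers(0, 256, size=(1000, 28, 28), dtype=np.uint8)
b = np.concatenate([a[[3, 7, 42]],
                    rng.integers(0, 256, size=(50, 28, 28), dtype=np.uint8)])

common = intersect_along_first_axis(a, b)
print(common.shape)  # (3, 28, 28)
print(len(common))   # number of matrices present in both arrays
```

    Note that `intersect1d` returns the matches sorted by their raw bytes and without duplicates, not in the original order.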

    Note that the void view is unsafe for floating-point data: it compares raw bytes, so -0.0 and 0.0 are treated as unequal even though == considers them equal.
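    A minimal sketch of the problem and one possible workaround, assuming float input (the 2x2 arrays are hypothetical). Adding +0.0 turns every -0.0 into +0.0 under IEEE 754 round-to-nearest, so normalizing both arrays this way before taking the void view makes the byte comparison agree with == for zeros:

```python
import numpy as np

# The same matrix, once containing a negative zero (hypothetical data)
x = np.array([[0.0, 1.0], [2.0, 3.0]])
y = np.array([[-0.0, 1.0], [2.0, 3.0]])

print(np.array_equal(x, y))        # True: == treats -0.0 and 0.0 as equal
print(x.tobytes() == y.tobytes())  # False: the sign bit differs, so a void view differs too

# -0.0 + 0.0 == +0.0, so this normalizes negative zeros without changing other values
y_norm = y + 0.0
print(x.tobytes() == y_norm.tobytes())  # True
```

    This only addresses signed zeros; it does not change how NaN values compare.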
