Commit 3e8b0e26 authored by mattip's avatar mattip

hack at strides and shapes till external loops can be used

parent 0ded1722
......@@ -449,7 +449,7 @@ class SliceArray(BaseConcreteArray):
strides.reverse()
backstrides.reverse()
new_shape.reverse()
return SliceArray(self.start, strides, backstrides, new_shape,
return type(self)(self.start, strides, backstrides, new_shape,
self, orig_array)
new_strides = calc_new_strides(new_shape, self.get_shape(),
self.get_strides(),
......@@ -460,10 +460,16 @@ class SliceArray(BaseConcreteArray):
new_backstrides = [0] * len(new_shape)
for nd in range(len(new_shape)):
new_backstrides[nd] = (new_shape[nd] - 1) * new_strides[nd]
return SliceArray(self.start, new_strides, new_backstrides, new_shape,
return type(self)(self.start, new_strides, new_backstrides, new_shape,
self, orig_array)
class NonWritableSliceArray(SliceArray):
def descr_setitem(self, space, orig_array, w_index, w_value):
raise OperationError(space.w_ValueError, space.wrap(
"assignment destination is read-only"))
class VoidBoxStorage(BaseConcreteArray):
def __init__(self, size, dtype):
self.storage = alloc_raw_storage(size)
......
......@@ -8,8 +8,8 @@ Given an array x: x.shape == [5,6], where each element occupies one byte
At which byte in x.data does the item x[3,4] begin?
if x.strides==[1,5]:
pData = x.pData + (x.start + 3*1 + 4*5)*sizeof(x.pData[0])
pData = x.pData + (x.start + 24) * sizeof(x.pData[0])
so the offset of the element is 24 elements after the first
pData = x.pData + (x.start + 23) * sizeof(x.pData[0])
so the offset of the element is 23 elements after the first
What is the next element in x after coordinates [3,4]?
if x.order =='C':
......@@ -33,7 +33,7 @@ shape dimension
which is x.strides[1] * (x.shape[1] - 1) + x.strides[0]
so if we precalculate the overflow backstride as
[x.strides[i] * (x.shape[i] - 1) for i in range(len(x.shape))]
we can go faster.
we can do only addition while iterating
All the calculations happen in next()
"""
from rpython.rlib import jit
......@@ -208,6 +208,12 @@ class ArrayIter(object):
assert state.iterator is self
self.array.setitem(state.offset, elem)
def getoperand(self, st, base):
impl = self.operand_type
res = impl([], self.array.dtype, self.array.order, [], [],
self.array.storage, base)
res.start = st.offset
return res
def AxisIter(array, shape, axis, cumulative):
strides = array.get_strides()
......@@ -238,18 +244,26 @@ class SliceIter(ArrayIter):
view into the original array
'''
def __init__(self, array, size, shape, strides, backstrides, op_flags):
def __init__(self, array, size, shape, strides, backstrides, slice_shape,
slice_stride, slice_backstride, op_flags, base):
from pypy.module.micronumpy import concrete
ArrayIter.__init__(self, array, size, shape, strides, backstrides, op_flags)
self.slice_shape = array.get_shape()[len(shape):]
self.slice_strides = array.strides[len(shape):]
self.slice_backstrides = array.backstrides[len(shape):]
self.shape = shape[:]
self.slice_shape = slice_shape
self.slice_stride = slice_stride
self.slice_backstride = slice_backstride
self.base = base
if op_flags.rw == 'r':
self.operand_type = concrete.NonWritableSliceArray
else:
self.operand_type = concrete.SliceArray
def getitem(self, state):
from pypy.module.micronumpy.concrete import SliceArray
assert state.iterator is self
return SliceArray(state.offset, self.slice_strides,
self.slice_backstrides, self.slice_shape, self.array,
self.array)
impl = self.operand_type
arr = impl(state.offset, [self.slice_stride], [self.slice_backstride],
[self.slice_shape], self.array, self.base)
return arr
def getitem_bool(self, state):
# XXX cannot be called
......@@ -257,8 +271,11 @@ class SliceIter(ArrayIter):
def setitem(self, state, elem):
assert state.iterator is self
slice = SliceArray(state.offset, self.slice_strides,
self.slice_backstrides, self.slice_shape, self.array,
self.array)
impl = self.operand_type
slice = impl(state.offset, [self.slice_stride], [self.slice_backstride],
[self.shape], self.array, self.base)
# TODO: implement
assert False
def getoperand(self, state, base):
return self.getitem(state)
......@@ -179,20 +179,20 @@ def coalesce_axes(it, space):
# Copy logic from npyiter_coalesce_axes, used in ufunc iterators
# and in nditer's with 'external_loop' flag
out_shape = it.shape[:]
can_coalesce = True
if it.order == 'F':
fastest = 0
else:
fastest = -1
for idim in range(it.ndim - 1):
can_coalesce = True
for op_it, _ in it.iters:
if op_it is None:
continue
assert isinstance(op_it, ArrayIter)
if len(op_it.shape_m1) < 2:
can_coalesce = False
continue
if len(op_it.shape_m1) != len(it.shape):
indx = len(op_it.strides)
if op_it.array.strides[:indx] != op_it.strides:
can_coalesce = False
break
if op_it.strides[-1] * op_it.shape_m1[-1] != op_it.backstrides[-1]:
can_coalesce = False
if can_coalesce:
if it.order == 'F':
last = out_shape[0]
......@@ -203,13 +203,60 @@ def coalesce_axes(it, space):
for i in range(len(it.iters)):
old_iter = it.iters[i][0]
shape = [s+1 for s in old_iter.shape_m1]
new_iter = SliceIter(old_iter.array, old_iter.size / last,
shape[:-1], old_iter.strides[:-1],
old_iter.backstrides[:-1], it.op_flags[i])
strides = old_iter.strides
backstrides = old_iter.backstrides
new_shape = shape[:-1]
new_strides = strides[:-1]
new_backstrides = backstrides[:-1]
_shape = shape[-1]
_stride = strides[fastest]
_backstride = backstrides[-1]
if isinstance(old_iter, SliceIter):
_shape *= old_iter.slice_shape
_stride = old_iter.slice_stride
_backstride = (_shape - 1) * _stride
new_iter = SliceIter(old_iter.array, old_iter.size / shape[-1],
new_shape, new_strides, new_backstrides,
_shape, _stride, _backstride,
it.op_flags[i], it)
if len(shape) > 1:
it.shape = out_shape
else:
it.shape = [1]
it.iters[i] = (new_iter, new_iter.reset())
else:
break
# Always coalesce at least one
if it.order == 'F':
last = out_shape[0]
out_shape = out_shape[1:]
else:
last = out_shape[-1]
out_shape = out_shape[:-1]
for i in range(len(it.iters)):
old_iter = it.iters[i][0]
shape = [s+1 for s in old_iter.shape_m1]
strides = old_iter.strides
backstrides = old_iter.backstrides
new_shape = shape[:-1]
new_strides = strides[:-1]
new_backstrides = backstrides[:-1]
_shape = shape[-1]
_stride = strides[-1]
_backstride = backstrides[-1]
if isinstance(old_iter, SliceIter):
_shape *= old_iter.slice_shape
_stride = old_iter.slice_stride
_backstride = (_shape - 1) * _stride
new_iter = SliceIter(old_iter.array, old_iter.size / shape[-1],
new_shape, new_strides, new_backstrides,
_shape, _stride, _backstride,
it.op_flags[i], it)
if len(shape) > 1:
it.shape = out_shape
else:
return
it.shape = [1]
it.iters[i] = (new_iter, new_iter.reset())
class IndexIterator(object):
def __init__(self, shape, backward=False):
......@@ -377,7 +424,7 @@ class W_NDIter(W_Root):
raise oefmt(space.w_ValueError,
"If op_axes is provided, at least one list of axes "
"must be contained within it")
raise Exception('xxx TODO')
raise oefmt(space.w_NotImplementedError, "op_axis not finished yet")
# Check that values make sense:
# - in bounds for each operand
# ValueError: Iterator input op_axes[0][3] (==3) is not a valid axis of op[0], which has 2 dimensions
......@@ -389,10 +436,7 @@ class W_NDIter(W_Root):
return space.wrap(self)
def getitem(self, it, st):
impl = it.operand_type
res = impl([], it.array.dtype, it.array.order, [], [],
it.array.storage, self)
res.start = st.offset
res = it.getoperand(st, self)
return W_NDimArray(res)
def descr_getitem(self, space, w_idx):
......@@ -411,7 +455,6 @@ class W_NDIter(W_Root):
space.wrap(len(self.iters))
def descr_next(self, space):
import pdb;pdb.set_trace()
for it, st in self.iters:
if not it.done(st):
break
......
......@@ -76,7 +76,9 @@ class AppTestNDIter(BaseNumpyAppTest):
r.append(x)
n += 1
assert n == 12
assert (array(r) == [[ 0, 12], [ 4, 16], [ 8, 20], [ 1, 13], [ 5, 17], [ 9, 21], [ 2, 14], [ 6, 18], [10, 22], [ 3, 15], [ 7, 19], [11, 23]]).all()
assert (array(r) == [[ 0, 12], [ 4, 16], [ 8, 20], [ 1, 13], [ 5, 17], [ 9, 21],
[ 2, 14], [ 6, 18], [10, 22], [ 3, 15], [ 7, 19], [11, 23],
]).all()
e = raises(ValueError, 'r[0][0] = 0')
assert str(e.value) == 'assignment destination is read-only'
r = []
......@@ -250,6 +252,10 @@ class AppTestNDIter(BaseNumpyAppTest):
a = arange(3)
import sys
b = arange(8).reshape(2,4)
if '__pypy__' in sys.builtin_module_names:
raises(NotImplementedError, nditer, [a, b, None], flags=['external_loop'],
op_axes=[[0, -1, -1], [-1, 0, 1], None])
skip('nditer op_axes not implemented yet')
it = nditer([a, b, None], flags=['external_loop'],
op_axes=[[0, -1, -1], [-1, 0, 1], None])
for x, y, z in it:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment