Compressed Sparse Row Format
The compressed sparse row (CSR) format stores a sparse matrix more compactly than the COO format does. Suppose you have
the following coordinate (COO) representation of a sparse matrix, sorted by row index:
rows = [0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6]
cols = [1, 2, 4, 0, 2, 3, 0, 1, 3, 4, 1, 2, 5, 6, 0, 2, 5, 3, 4, 6, 3, 5]
values = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
If there are nnz nonzeroes, then this representation requires 3*nnz units of storage since it needs three
arrays, each of length nnz.
Observe that when we sort by row index, row indices are repeated wherever there is more than one
nonzero in a row. The idea behind CSR is to exploit this redundancy. From the sorted COO
representation, we keep the column indices and values as-is. Then, instead of storing every row
index, we just store the starting offset of each row in those two lists, which we'll refer to as the row
pointers, stored as the list rowptr, below:
cols = [1, 2, 4, 0, 2, 3, 0, 1, 3, 4, 1, 2, 5, 6, 0, 2, 5, 3, 4, 6, 3, 5]
values = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
rowptr = [0, 3, 6, 10, 14, 17, 20, 22]
If the sparse matrix has n rows, then the rowptr list has n+1 elements, where the last element
(rowptr[-1] == rowptr[n]) is nnz. (Why might you need this last element?)
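To make the role of the row pointers concrete, here is a minimal sketch (not part of the original exercise) that reads a single row back out of the CSR lists from the example above. The helper name `csr_row` is hypothetical; the three lists are copied from the example. It also compares the storage counts of the two formats, assuming one unit of storage per list element.

```python
# CSR lists copied from the example above
cols = [1, 2, 4, 0, 2, 3, 0, 1, 3, 4, 1, 2, 5, 6, 0, 2, 5, 3, 4, 6, 3, 5]
values = [1] * len(cols)
rowptr = [0, 3, 6, 10, 14, 17, 20, 22]

def csr_row(i, rowptr, cols, values):
    """Return the (column, value) pairs stored for row `i` (hypothetical helper)."""
    start, end = rowptr[i], rowptr[i + 1]  # row i occupies the half-open slice [start, end)
    return list(zip(cols[start:end], values[start:end]))

n = len(rowptr) - 1  # number of rows
nnz = rowptr[-1]     # number of stored nonzeros

print(csr_row(2, rowptr, cols, values))      # [(0, 1), (1, 1), (3, 1), (4, 1)]
print(csr_row(n - 1, rowptr, cols, values))  # [(3, 1), (5, 1)]

# Storage: COO needs 3*nnz entries, CSR needs 2*nnz + (n + 1)
print(3 * nnz, 2 * nnz + n + 1)              # 66 52
```

Note that reading the last row (`i == n - 1`) looks up `rowptr[n]`, so every row can be sliced the same way without a special case.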
Exercise 8 (3 points). Complete the function, coo2csr(coo_rows, coo_cols, coo_vals), below. The
inputs are three Python lists corresponding to a sparse matrix in COO format, like the example
illustrated above. Your function should return a triple, (csr_ptrs, csr_inds, csr_vals), corresponding to
the same matrix but stored in CSR format, again, like what is shown above, where csr_ptrs would be
the row pointers (rowptr), csr_inds would be the column indices (cols), and csr_vals would be the
values (values).
To help you out, we show how to calculate csr_inds and csr_vals. You need to figure out how to
compute csr_ptrs. The function is set up to return these three lists.
Although the test cell does not check it, in principle, your implementation should also work correctly
if the matrix has an empty row. In such cases, what should the CSR data structure look like?
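One way to sanity-check a CSR result, including the empty-row case, is to expand it back into COO form and compare against the sorted input. The sketch below is our own illustration, not part of the assignment: the helper `csr2coo` and the small hand-built lists are hypothetical, and the example assumes that an empty row contributes no entries to the index and value lists.

```python
def csr2coo(csr_ptrs, csr_inds, csr_vals):
    """Expand CSR lists back into row-sorted COO lists (hypothetical helper)."""
    coo_rows = []
    for i in range(len(csr_ptrs) - 1):
        count = csr_ptrs[i + 1] - csr_ptrs[i]  # number of nonzeros stored for row i
        coo_rows.extend([i] * count)           # repeat the row index once per nonzero
    return coo_rows, list(csr_inds), list(csr_vals)

# Hand-built 3-row example in which row 1 has no nonzeros
ptrs = [0, 2, 2, 3]
inds = [0, 2, 1]
vals = [5, 7, 9]

rows, cols_out, vals_out = csr2coo(ptrs, inds, vals)
print(rows)      # [0, 0, 2]
print(cols_out)  # [0, 2, 1]
print(vals_out)  # [5, 7, 9]
```

Row 1 never appears in `rows`, because the difference between its two surrounding pointers is zero.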
Write code here
def coo2csr(coo_rows, coo_cols, coo_vals):
    from operator import itemgetter

    # Sort the nonzeros by row index
    C = sorted(zip(coo_rows, coo_cols, coo_vals), key=itemgetter(0))
    nnz = len(C)
    assert nnz >= 1

    # Column indices and values carry over unchanged from the sorted COO data
    csr_inds = [j for _, j, _ in C]
    csr_vals = [a_ij for _, _, a_ij in C]

    # Your task: Compute `csr_ptrs`
    ###
    ### YOUR CODE HERE
    ###

    return csr_ptrs, csr_inds, csr_vals
Testing codes
# Test cell 0: `create_csr_test` (1 point)

csr_ptrs, csr_inds, csr_vals = coo2csr(coo_rows, coo_cols, coo_vals)

assert type(csr_ptrs) is list, "`csr_ptrs` is not a list."
assert type(csr_inds) is list, "`csr_inds` is not a list."
assert type(csr_vals) is list, "`csr_vals` is not a list."

assert len(csr_ptrs) == (num_verts + 1), "`csr_ptrs` has {} values instead of {}".format(len(csr_ptrs), num_verts+1)
assert len(csr_inds) == num_edges, "`csr_inds` has {} values instead of {}".format(len(csr_inds), num_edges)
assert len(csr_vals) == num_edges, "`csr_vals` has {} values instead of {}".format(len(csr_vals), num_edges)
assert csr_ptrs[num_verts] == num_edges, "`csr_ptrs[{}]` == {} instead of {}".format(num_verts, csr_ptrs[num_verts], num_edges)

# Check some random entries
for i in sample(range(num_verts), 10000):
    assert i in G
    a, b = csr_ptrs[i], csr_ptrs[i+1]
    msg_prefix = "Row {} should have these nonzeros: {}".format(i, G[i])
    assert (b-a) == len(G[i]), "{}, which is {} nonzeros; instead, it has just {}.".format(msg_prefix, len(G[i]), b-a)
    # Every stored column index must be a neighbor of i ...
    assert all([(j in G[i]) for j in csr_inds[a:b]]), "{}. However, it may have missing or incorrect column indices: csr_inds[{}:{}] == {}".format(msg_prefix, a, b, csr_inds[a:b])
    # ... and every neighbor of i must appear among the stored column indices
    assert all([(j in csr_inds[a:b]) for j in G[i].keys()]), "{}. However, it may have missing or incorrect column indices: csr_inds[{}:{}] == {}".format(msg_prefix, a, b, csr_inds[a:b])

print("\n(Passed.)")