Projects/matrix_mult.py

43 lines
1.2 KiB
Python

from typing import List
def multiply(M: List[List], N: List[List]) -> List[List]:
    """Multiply two matrices M and N and return the resulting matrix.

    Args:
        M (List[List]): The first matrix as a list of rows, with dimensions m x k.
        N (List[List]): The second matrix as a list of rows, with dimensions k x n.

    Returns:
        List[List]: A new matrix (list of rows) with dimensions m x n. An empty M
        yields ``[]``. When N has no rows (k == 0), n is taken to be 0, so each
        result row is ``[]``.

    Raises:
        AssertionError: If the number of columns in M does not equal the number
            of rows in N. This is raised explicitly, so the check still runs
            under ``python -O``.
        ValueError: If the rows of M, or the rows of N, do not all have the
            same length.

    Example:
        >>> M = [
        ...     [1, 2, 4],
        ...     [0, 1, 2],
        ... ]
        >>> N = [
        ...     [1, 2],
        ...     [3, 4],
        ...     [0, 1],
        ... ]
        >>> multiply(M, N)
        [[7, 14], [3, 6]]

    NOTE: When importing the matrix_mult module into another project, call it
    as ``matrix_mult.multiply(M, N)``.
    """
    if not M:
        # A 0 x k matrix times anything compatible is a 0 x n matrix: no rows.
        return []

    mcols = len(M[0])
    if any(len(row) != mcols for row in M):
        raise ValueError("All rows of M must have the same length")

    if mcols != len(N):
        # Explicit raise rather than `assert`, which is stripped under -O;
        # AssertionError is kept because it is the documented exception type.
        raise AssertionError("Matrices are not compatible for multiplication")

    ncols = len(N[0]) if N else 0
    if any(len(row) != ncols for row in N):
        raise ValueError("All rows of N must have the same length")

    # Transpose N once so each output entry is the dot product of a row of M
    # with a column of N.
    columns = list(zip(*N))
    return [
        [sum(a * b for a, b in zip(row, col)) for col in columns]
        for row in M
    ]