// p0_starter.h
//===----------------------------------------------------------------------===//
//
// BusTub
//
// p0_starter.h
//
// Identification: src/include/primer/p0_starter.h
//
// Copyright (c) 2015-2020, Carnegie Mellon University Database Group
//
//===----------------------------------------------------------------------===//
#pragma once
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include "common/exception.h"
#include "common/logger.h"
namespace bustub {
/**
* The Matrix type defines a common
* interface for matrix operations.
*/
template <typename T>
class Matrix {
 protected:
  /**
   * Construct a new Matrix instance.
   *
   * If either dimension is non-positive, no storage is allocated, `linear_`
   * stays null, and both `rows_` and `cols_` are set to -1 to mark the matrix
   * as invalid.
   *
   * @param rows The number of rows
   * @param cols The number of columns
   */
  Matrix(int rows, int cols) : rows_(rows), cols_(cols), linear_(nullptr) {
    if (rows_ <= 0 || cols_ <= 0) {
      rows_ = -1;
      cols_ = -1;
      return;
    }
    // Form the element count in size_t rather than int so that large dimensions
    // do not trigger signed integer overflow. The trailing `()` value-initializes
    // every element, so reading a cell that was never set is well defined.
    linear_ = new T[static_cast<std::size_t>(rows_) * static_cast<std::size_t>(cols_)]();
  }

  /** The number of rows in the matrix (-1 when the matrix is invalid) */
  int rows_;

  /** The number of columns in the matrix (-1 when the matrix is invalid) */
  int cols_;

  /**
   * A flattened array containing the elements of the matrix.
   * Allocated in the constructor and released in the destructor;
   * null when the matrix was constructed with invalid dimensions.
   */
  T *linear_;

 public:
  // The matrix owns `linear_` through a raw pointer. A member-wise copy would
  // make two objects delete the same array, so copying is disabled.
  Matrix(const Matrix &) = delete;
  Matrix &operator=(const Matrix &) = delete;

  /** @return The number of rows in the matrix */
  virtual int GetRowCount() const = 0;

  /** @return The number of columns in the matrix */
  virtual int GetColumnCount() const = 0;

  /**
   * Get the (i,j)th matrix element.
   *
   * Throw OUT_OF_RANGE if either index is out of range.
   *
   * @param i The row index
   * @param j The column index
   * @return The (i,j)th matrix element
   * @throws OUT_OF_RANGE if either index is out of range
   */
  virtual T GetElement(int i, int j) const = 0;

  /**
   * Set the (i,j)th matrix element.
   *
   * Throw OUT_OF_RANGE if either index is out of range.
   *
   * @param i The row index
   * @param j The column index
   * @param val The value to insert
   * @throws OUT_OF_RANGE if either index is out of range
   */
  virtual void SetElement(int i, int j, T val) = 0;

  /**
   * Fill the elements of the matrix from `source`.
   *
   * Throw OUT_OF_RANGE in the event that `source`
   * does not contain the required number of elements.
   *
   * @param source The source container
   * @throws OUT_OF_RANGE if `source` is incorrect size
   */
  virtual void FillFrom(const std::vector<T> &source) = 0;

  /**
   * Destroy a matrix instance and release the flattened element array.
   * Virtual so that deleting a derived matrix through a `Matrix *` is safe.
   */
  virtual ~Matrix() { delete[] linear_; }
};
/**
* The RowMatrix type is a concrete matrix implementation.
* It implements the interface defined by the Matrix type.
*/
template <typename T>
class RowMatrix : public Matrix<T> {
 public:
  /**
   * Construct a new RowMatrix instance.
   *
   * When either dimension is non-positive the matrix is marked invalid
   * (`rows_` and `cols_` are -1) and no row-pointer table is allocated.
   *
   * @param rows The number of rows
   * @param cols The number of columns
   */
  RowMatrix(int rows, int cols) : Matrix<T>(rows, cols), data_(nullptr) {
    if (this->rows_ <= 0 || this->cols_ <= 0) {
      this->rows_ = -1;
      this->cols_ = -1;
      return;
    }
    // Only one pointer per row is needed; each points at the first element
    // of that row inside the flattened `linear_` array.
    data_ = new T *[this->rows_];
    T *row_start = this->linear_;
    for (int i = 0; i < this->rows_; i++) {
      data_[i] = row_start;
      row_start += this->cols_;
    }
  }

  // The matrix owns `data_` through a raw pointer; a member-wise copy would
  // lead to a double delete, so copying is disabled.
  RowMatrix(const RowMatrix &) = delete;
  RowMatrix &operator=(const RowMatrix &) = delete;

  /** @return The number of rows in the matrix */
  int GetRowCount() const override { return this->rows_; }

  /** @return The number of columns in the matrix */
  int GetColumnCount() const override { return this->cols_; }

  /**
   * Get the (i,j)th matrix element.
   *
   * Throw OUT_OF_RANGE if either index is out of range.
   *
   * @param i The row index
   * @param j The column index
   * @return The (i,j)th matrix element
   * @throws OUT_OF_RANGE if either index is out of range
   */
  T GetElement(int i, int j) const override {
    if (!IsValidIndex(i, j)) {
      // Build the message with std::to_string: adding an int to a string
      // literal is pointer arithmetic, not concatenation.
      throw Exception(ExceptionType::OUT_OF_RANGE, "RowMatrix::GetElement(int,int) out of range:i=" +
                                                       std::to_string(i) + ",j=" + std::to_string(j));
    }
    return data_[i][j];
  }

  /**
   * Set the (i,j)th matrix element.
   *
   * Throw OUT_OF_RANGE if either index is out of range.
   *
   * @param i The row index
   * @param j The column index
   * @param val The value to insert
   * @throws OUT_OF_RANGE if either index is out of range
   */
  void SetElement(int i, int j, T val) override {
    if (!IsValidIndex(i, j)) {
      throw Exception(ExceptionType::OUT_OF_RANGE, "RowMatrix::SetElement(int,int,T) out of range:i=" +
                                                       std::to_string(i) + ",j=" + std::to_string(j));
    }
    data_[i][j] = val;
  }

  /**
   * Fill the elements of the matrix from `source`, in row-major order.
   *
   * Throw OUT_OF_RANGE in the event that `source`
   * does not contain the required number of elements.
   * An invalid matrix requires zero elements.
   *
   * @param source The source container
   * @throws OUT_OF_RANGE if `source` is incorrect size
   */
  void FillFrom(const std::vector<T> &source) override {
    // For an invalid matrix both dimensions are -1, whose product would be 1;
    // compute the expected count explicitly so such a matrix expects nothing.
    const bool valid = this->rows_ > 0 && this->cols_ > 0;
    const std::size_t expected =
        valid ? static_cast<std::size_t>(this->rows_) * static_cast<std::size_t>(this->cols_) : 0;
    if (source.size() != expected) {
      throw Exception(ExceptionType::OUT_OF_RANGE, "RowMatrix::FillFrom(vector) out of range");
    }
    // `linear_` is laid out row-major, exactly like `source`.
    for (std::size_t k = 0; k < expected; k++) {
      this->linear_[k] = source[k];
    }
  }

  /**
   * Destroy a RowMatrix instance.
   * Releases only the row-pointer table; the element storage itself is
   * released by the base class destructor.
   */
  ~RowMatrix() override { delete[] data_; }

 private:
  /** @return true when (i, j) addresses an element inside the matrix */
  bool IsValidIndex(int i, int j) const { return i >= 0 && j >= 0 && i < this->rows_ && j < this->cols_; }

  /**
   * A 2D view of the matrix elements in row-major format: `data_[i]` points at
   * the first element of row `i` inside the base class's `linear_` array.
   * Allocated in the constructor and released in the destructor; null when the
   * matrix is invalid.
   */
  T **data_;
};
/**
* The RowMatrixOperations class defines operations
* that may be performed on instances of `RowMatrix`.
*/
template <typename T>
class RowMatrixOperations {
 public:
  /**
   * Compute (`matrixA` + `matrixB`) and return the result.
   * Return `nullptr` if dimensions mismatch for input matrices
   * or if either input pointer is null.
   * @param matrixA Input matrix
   * @param matrixB Input matrix
   * @return The result of matrix addition
   */
  static std::unique_ptr<RowMatrix<T>> Add(const RowMatrix<T> *matrixA, const RowMatrix<T> *matrixB) {
    // Null checks must come before any member access on the inputs.
    if (matrixA == nullptr || matrixB == nullptr) {
      return nullptr;
    }
    const int rows = matrixA->GetRowCount();
    const int cols = matrixA->GetColumnCount();
    if (rows != matrixB->GetRowCount() || cols != matrixB->GetColumnCount()) {
      return nullptr;
    }
    auto result = std::make_unique<RowMatrix<T>>(rows, cols);
    for (int i = 0; i < rows; i++) {
      for (int j = 0; j < cols; j++) {
        result->SetElement(i, j, matrixA->GetElement(i, j) + matrixB->GetElement(i, j));
      }
    }
    return result;
  }

  /**
   * Compute the matrix multiplication (`matrixA` * `matrixB`) and return the result.
   * Return `nullptr` if dimensions mismatch for input matrices
   * or if either input pointer is null.
   * @param matrixA Input matrix
   * @param matrixB Input matrix
   * @return The result of matrix multiplication
   */
  static std::unique_ptr<RowMatrix<T>> Multiply(const RowMatrix<T> *matrixA, const RowMatrix<T> *matrixB) {
    // Null checks must come before any member access on the inputs.
    if (matrixA == nullptr || matrixB == nullptr) {
      return nullptr;
    }
    const int rows = matrixA->GetRowCount();
    const int inner = matrixA->GetColumnCount();
    const int cols = matrixB->GetColumnCount();
    if (inner != matrixB->GetRowCount()) {
      return nullptr;
    }
    auto result = std::make_unique<RowMatrix<T>>(rows, cols);
    for (int row = 0; row < rows; row++) {
      for (int col = 0; col < cols; col++) {
        // Seed the dot product with the first term, then accumulate the rest.
        // The loops only run when `rows` and `cols` are positive, and RowMatrix
        // stores -1 for both dimensions of an invalid matrix, so `inner` >= 1 here.
        T dot = matrixA->GetElement(row, 0) * matrixB->GetElement(0, col);
        for (int k = 1; k < inner; k++) {
          dot += matrixA->GetElement(row, k) * matrixB->GetElement(k, col);
        }
        result->SetElement(row, col, dot);
      }
    }
    return result;
  }

  /**
   * Simplified General Matrix Multiply operation. Compute (`matrixA` * `matrixB` + `matrixC`).
   * Return `nullptr` if dimensions mismatch for input matrices
   * or if any input pointer is null.
   * @param matrixA Input matrix
   * @param matrixB Input matrix
   * @param matrixC Input matrix
   * @return The result of general matrix multiply
   */
  static std::unique_ptr<RowMatrix<T>> GEMM(const RowMatrix<T> *matrixA, const RowMatrix<T> *matrixB,
                                            const RowMatrix<T> *matrixC) {
    if (matrixC == nullptr) {
      return nullptr;
    }
    // Keep the product alive in a named owner and stop early when the
    // multiplication itself failed, instead of forwarding a null pointer.
    const auto product = Multiply(matrixA, matrixB);
    if (product == nullptr) {
      return nullptr;
    }
    return Add(product.get(), matrixC);
  }
};
} // namespace bustub