optimize the int4 matmul
chenqy4933 committed Aug 3, 2023
1 parent 3eafdeb commit 239d4d7
Showing 13 changed files with 646 additions and 19 deletions.
44 changes: 37 additions & 7 deletions src/core/op.cpp
@@ -1,6 +1,7 @@
#include "op.h"
#include <fstream>
#include <iostream>
#include <string>
#include "kern/kernel.h"
#include "kern/naive/naive.h"
using namespace inferllm;
@@ -153,7 +154,7 @@ void SpliteHalfActiveMul::execute(WorkSpace*, uint32_t) {

void MatMul::execute(WorkSpace* workspace, uint32_t) {
auto N = weights()[0]->shape()[0];
auto K = weights()[0]->shape()[1];
auto K = inputs()[0]->shape()[1];
auto M = inputs()[0]->shape()[0];
auto src_dtype = inputs()[0]->dtype();
auto weight_dtype = weights()[0]->dtype();
@@ -169,9 +170,15 @@ void MatMul::execute(WorkSpace* workspace, uint32_t) {
const float* src = inputs()[0]->ptr<float>();
switch (weight_dtype) {
case DType::Int4:
kernel->operator()<KernelID::MatmulInt4Float>(
dst, weights()[0]->ptr(), bias, src, M, N, K, p_workspace,
p_workspace_size);
if (!m_weight_packed) {
kernel->operator()<KernelID::MatmulInt4Float>(
dst, weights()[0]->ptr(), bias, src, M, N, K, p_workspace,
p_workspace_size);
} else {
kernel->operator()<KernelID::MatmulInt4FloatPacked>(
dst, weights()[0]->ptr(), bias, src, M, N * PACK_SIZE, K,
p_workspace, p_workspace_size);
}
break;
case DType::Int8:
kernel->operator()<KernelID::MatmulInt8Float>(
@@ -203,6 +210,23 @@ size_t MatMul::get_workspace_in_byte() {
return 0;
}

//! all of the memory involved here is host memory
std::vector<size_t> MatMul::preprocess_weight(Tensor* tensor, void* src, void* dst) {
INFER_ASSERT(tensor->dtype() == DType::Int4, "only int4 weights support the optimized kernel");
auto weight_shape = tensor->shape();
size_t M = weight_shape[0];
size_t N = weight_shape[1];
INFER_ASSERT(N % QK40 == 0, "the embd size is not divisible by QK40.");
INFER_ASSERT(M % PACK_SIZE == 0, "the matmul M dimension is not aligned to PACK_SIZE (8).");

auto kernel = get_kernel();
kernel->operator()<KernelID::MatmulInt4WeightReorder>(M, N, dst, src, PACK_SIZE);
size_t block_m = M / PACK_SIZE;

m_weight_packed = true;
return {block_m, N * PACK_SIZE};
}

void MatMulLast::execute(WorkSpace* workspace, uint32_t) {
auto N = weights()[0]->shape()[0];
auto K = weights()[0]->shape()[1];
@@ -223,9 +247,15 @@ void MatMulLast::execute(WorkSpace* workspace, uint32_t) {
const float* src = inputs()[0]->ptr<float>() + (row - 1) * K;
switch (weight_dtype) {
case DType::Int4:
kernel->operator()<KernelID::MatmulInt4Float>(
dst, weights()[0]->ptr(), bias, src, M, N, K, p_workspace,
p_workspace_size);
if (!m_weight_packed) {
kernel->operator()<KernelID::MatmulInt4Float>(
dst, weights()[0]->ptr(), bias, src, M, N, K, p_workspace,
p_workspace_size);
} else {
kernel->operator()<KernelID::MatmulInt4FloatPacked>(
dst, weights()[0]->ptr(), bias, src, M, N * PACK_SIZE, K,
p_workspace, p_workspace_size);
}
break;
case DType::Int8:
kernel->operator()<KernelID::MatmulInt8Float>(
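Note: the reorder kernel invoked above (KernelID::MatmulInt4WeightReorder) lives in one of the kernel source files whose diff is not expanded on this page. Below is a minimal sketch of what a PACK_SIZE = 8 reorder can look like, using the BlockQ40 / BlockQ40X8 layouts added in src/kern/kernel_define.h further down; the function name, signature details, and the exact ordering of the nibble bytes inside BlockQ40X8::qs are assumptions, not the repository's actual implementation.

// Sketch only: packs 8 consecutive weight rows of BlockQ40 blocks into BlockQ40X8 blocks.
// Assumes the BlockQ40 / BlockQ40X8 / QK40 definitions from src/kern/kernel_define.h are in scope.
#include <cstddef>
#include <cstdint>
#include <cstring>

void int4_weight_reorder_sketch(
        size_t M, size_t N, void* dst, const void* src, size_t pack_size /* expected to be 8 */) {
    const size_t blocks_per_row = N / QK40;              // the caller asserts N % QK40 == 0
    const auto* in = static_cast<const BlockQ40*>(src);  // M x blocks_per_row, row-major
    auto* out = static_cast<BlockQ40X8*>(dst);           // (M / pack_size) x blocks_per_row

    for (size_t m = 0; m < M; m += pack_size) {
        for (size_t b = 0; b < blocks_per_row; ++b) {
            BlockQ40X8& packed = out[(m / pack_size) * blocks_per_row + b];
            for (size_t r = 0; r < pack_size; ++r) {
                const BlockQ40& blk = in[(m + r) * blocks_per_row + b];
                packed.scale[r] = blk.d;
                // assumed layout: the 8 rows' 16-byte nibble arrays are simply concatenated
                std::memcpy(packed.qs + r * (QK40 / 2), blk.qs, QK40 / 2);
            }
        }
    }
}
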
37 changes: 36 additions & 1 deletion src/core/op.h
@@ -77,6 +77,15 @@ class OpBase {
OpIOs outputs() { return m_outputs; }
std::string name() { return m_name; }

//! to better optimize the computation, some ops need to preprocess their weights so
//! that the data layout is friendly to the compute kernel
virtual bool need_preprocess_weight(Tensor*) { return false; }

virtual std::vector<size_t> preprocess_weight(
Tensor* tensor, void* src, void* dst) {
return std::vector<size_t>();
}

private:
Device* m_device;
OpIOs m_weights;
@@ -141,15 +150,37 @@ class MatMul : public OpBase {
auto weight_shape = weights()[0]->shape();
auto input_shape = inputs()[0]->shape();
size_t M = input_shape.size() == 2 ? input_shape[0] : input_shape[1];
size_t K = weight_shape[1];
size_t N = weight_shape[0];
if (m_weight_packed) {
N = N * PACK_SIZE;
}
outputs()[0]->set_shape({M, N}, inputs()[0]->dtype());
}

virtual bool need_preprocess_weight(Tensor* weight) override {
auto kernel = get_kernel();
//! only when the weight is int4
if (weight->name() == weights()[0]->name()) {
bool optimized =
kernel->supported_optimization(KernelOptMethod::MatmulInt4Reorder);
bool int4 = weight->dtype() == DType::Int4;
if (optimized && int4) {
return true;
}
}
return false;
}

virtual std::vector<size_t> preprocess_weight(
Tensor* tensor, void* src, void* dst) override;

virtual void execute(WorkSpace* workspace, uint32_t nr_past) override;

size_t get_workspace_in_byte() override;

bool m_bias = false;
constexpr static size_t PACK_SIZE = 8;
bool m_weight_packed = false;
};

class MatMulLast : public MatMul {
@@ -163,9 +194,13 @@ class MatMulLast : public MatMul {
size_t M = 1;
size_t K = weight_shape[1];
size_t N = weight_shape[0];
if (m_weight_packed) {
N = N * PACK_SIZE;
}
outputs()[0]->set_shape({M, N}, inputs()[0]->dtype());
}
void execute(WorkSpace* workspace, uint32_t nr_past) override;
virtual bool need_preprocess_weight(Tensor*) override { return false; }

size_t get_workspace_in_byte() override;
};
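One detail worth calling out in the op.h changes above: preprocess_weight returns the new weight shape {M / PACK_SIZE, N * PACK_SIZE}, so the element count, and therefore the tensor's byte length, is unchanged (sizeof(BlockQ40X8) == 8 * sizeof(BlockQ40)), and infer_shape multiplies weights()[0]->shape()[0] back by PACK_SIZE to recover the true number of output channels. The snippet below uses hypothetical example dimensions (not taken from the repository) purely to illustrate the bookkeeping.

#include <cstddef>

// Hypothetical numbers, only to illustrate preprocess_weight() / infer_shape() bookkeeping.
constexpr size_t PACK_SIZE = 8;
constexpr size_t out_channels = 4096;                         // assumed original int4 weight shape {4096, 4096}
constexpr size_t embd = 4096;
constexpr size_t packed_shape0 = out_channels / PACK_SIZE;    // 512
constexpr size_t packed_shape1 = embd * PACK_SIZE;            // 32768 -> packed shape {512, 32768}
static_assert(packed_shape0 * packed_shape1 == out_channels * embd,
              "element count (and hence stored byte length) is unchanged by packing");
// MatMul::infer_shape() recovers the real output width: packed_shape0 * PACK_SIZE == 4096.
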
23 changes: 21 additions & 2 deletions src/core/tensor.cpp
@@ -2,6 +2,7 @@
#include "../kern/kernel_define.h"
#include "memory.h"
#include "utils.h"
#include "op.h"

using namespace inferllm;

@@ -90,12 +91,30 @@ size_t Tensor::read_data_from_file() {
if (!m_device->unified_memory()) {
m_data = m_device->allocate(length);
auto host_ptr = m_device->allocate_host(length);
m_file->read_data(host_ptr, length, m_file_offset);
auto opr = this->owner_op();
if (opr->need_preprocess_weight(this)) {
auto host_ptr2 = m_device->allocate_host(length);
m_file->read_data(host_ptr2, length, m_file_offset);
auto shape = opr->preprocess_weight(this, host_ptr2, host_ptr);
set_shape(shape);
m_device->free_host(host_ptr2);
} else {
m_file->read_data(host_ptr, length, m_file_offset);
}
m_device->host2device_copy(m_data, host_ptr, length);
m_device->free_host(host_ptr);
} else {
m_data = m_device->allocate(length);
m_file->read_data(m_data, length, m_file_offset);
auto opr = this->owner_op();
if (opr->need_preprocess_weight(this)) {
auto host_data = m_device->allocate_host(length);
m_file->read_data(host_data, length, m_file_offset);
auto shape = opr->preprocess_weight(this, host_data, m_data);
set_shape(shape);
m_device->free_host(host_data);
} else {
m_file->read_data(m_data, length, m_file_offset);
}
}
}
return length;
3 changes: 3 additions & 0 deletions src/kern/gpu/kernel_gpu.cu
@@ -762,5 +762,8 @@ size_t llm_matmul_get_workspace_float(uint32_t M, uint32_t N, uint32_t K) {
size_t llm_matmul_get_workspace_float_float(uint32_t M, uint32_t N, uint32_t K) {
return 0;
}

void llm_int4_matmul_weight_reorder(
size_t M, size_t N, void* dst, void* src, size_t PACK_SIZE) {}
} // namespace gpu
} // namespace inferllm
4 changes: 4 additions & 0 deletions src/kern/gpu/kernel_gpu.h
@@ -102,6 +102,9 @@ void llm_head_batched_matmul_compute_float(
float* dst, const float* v, const float* qk, uint32_t seqlen, uint32_t embd,
uint32_t head, uint32_t nr_past, cudaHandle* handle);

void llm_int4_matmul_weight_reorder(
size_t M, size_t N, void* dst, void* src, size_t PACK_SIZE);

template <KernelID Id, typename... Args>
struct Comp {
static void exec(Args... args, cudaHandle* handle);
@@ -167,6 +170,7 @@ PartialImplementKernel(
llm_matmul_compute_with_head_strideq_broadcastk_float);
PartialImplementKernel(
HeadBatchedMatmulBroadCastVFloat, llm_head_batched_matmul_broadcastv_float);
PartialImplementKernel(MatmulInt4WeightReorder, llm_int4_matmul_weight_reorder);

PartialImplementSpace(MatmulInt4Float, llm_matmul_get_workspace_float);
PartialImplementSpace(MatmulFloatFloat, llm_matmul_get_workspace_float_float);
10 changes: 10 additions & 0 deletions src/kern/kernel.h
@@ -39,6 +39,16 @@ class Kernel {
return m_thread_pool->nr_threads();
}

bool supported_optimization(KernelOptMethod method) {
if (m_kernel_type == KernelType::Arm || m_kernel_type == KernelType::Naive ||
m_kernel_type == KernelType::X86) {
if (method == KernelOptMethod::MatmulInt4Reorder) {
return true;
}
}
return false;
}

//! compute
template <KernelID Id, typename... Args>
void operator()(Args... args) {
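The new supported_optimization() query is what MatMul::need_preprocess_weight (in op.h above) uses to decide whether to pack a weight at load time; GPU kernels are not listed, which matches the empty llm_int4_matmul_weight_reorder stub in kernel_gpu.cu. A condensed view of that call pattern, written here as a free function for illustration:

// Condensed from MatMul::need_preprocess_weight() in op.h above; Kernel and Tensor
// are the inferllm types shown in this commit.
bool wants_packed_int4_weight(Kernel* kernel, Tensor* weight) {
    return kernel->supported_optimization(KernelOptMethod::MatmulInt4Reorder) &&
           weight->dtype() == DType::Int4;
}
// When this returns true, Tensor::read_data_from_file() routes the raw weight data
// through MatMul::preprocess_weight() before it reaches device memory.
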
14 changes: 14 additions & 0 deletions src/kern/kernel_define.h
@@ -25,6 +25,7 @@ enum class KernelID {
RmsNormFloat,
SoftmaxFloat,
MatmulInt4Float,
MatmulInt4FloatPacked,
MatmulInt8Float,
MatmulFloatFloat,
MatmulWithHeadStrideFloat,
@@ -39,6 +40,11 @@
DiagMaskFloat,
GlmGmask,
PermuteFloat,
MatmulInt4WeightReorder,
};

enum class KernelOptMethod {
MatmulInt4Reorder = 0,
};

enum class ElemMode {
@@ -75,12 +81,20 @@ struct BlockQ40 {
float d; // delta
uint8_t qs[QK40 / 2]; // nibbles / quants
};
static_assert(sizeof(BlockQ40) == 20, "BlockQ40 size error");

struct BlockQ40X8 {
float scale[8]; // delta
uint8_t qs[QK40 / 2 * 8]; // nibbles / quants
};
static_assert(sizeof(BlockQ40X8) == 160, "BlockQ40X8 size error");

#define QK80 32
struct BlockQ80 {
float d; // delta
int8_t qs[QK80]; // nibbles
};
static_assert(sizeof(BlockQ80) == 36, "BlockQ80 size error");
} // namespace inferllm

#define PartialImplementKernel(kernel_id, fun) \
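The MatmulInt4FloatPacked kernel itself is in the files not expanded on this page. As a rough illustration of why the BlockQ40X8 layout above helps: within each block of K values, eight output rows share one pass over the activation data and their eight scales sit next to each other, which is convenient for SIMD. Below is a heavily simplified scalar sketch, assuming the concatenated qs layout from the reorder sketch earlier and a Q4_0-style nibble mapping (low nibble = element i, high nibble = element i + QK40/2, both offset by 8); the nibble-to-element mapping is an assumption and the real kernel most likely quantizes the activations to int8 and uses vector intrinsics.

// Scalar sketch only; not the repository's MatmulInt4FloatPacked implementation.
// Assumes the BlockQ40X8 / QK40 definitions from src/kern/kernel_define.h are in scope.
#include <cstddef>
#include <cstdint>

void matmul_int4_packed_sketch(
        float* dst, const void* weight, const float* bias, const float* src,
        size_t M, size_t N, size_t K) {
    const auto* w = static_cast<const BlockQ40X8*>(weight);  // (N / 8) x (K / QK40) blocks
    const size_t blocks_per_row = K / QK40;
    for (size_t m = 0; m < M; ++m) {                          // activation rows
        const float* x = src + m * K;
        for (size_t n = 0; n < N; n += 8) {                   // 8 output channels per packed block row
            float acc[8] = {};
            const BlockQ40X8* row = w + (n / 8) * blocks_per_row;
            for (size_t b = 0; b < blocks_per_row; ++b) {
                const BlockQ40X8& blk = row[b];
                for (size_t r = 0; r < 8; ++r) {
                    const uint8_t* qs = blk.qs + r * (QK40 / 2);  // assumed concatenated layout
                    float sum = 0.f;
                    for (size_t i = 0; i < QK40 / 2; ++i) {
                        sum += float(int(qs[i] & 0x0F) - 8) * x[b * QK40 + i];
                        sum += float(int(qs[i] >> 4) - 8) * x[b * QK40 + i + QK40 / 2];
                    }
                    acc[r] += sum * blk.scale[r];
                }
            }
            for (size_t r = 0; r < 8; ++r)
                dst[m * N + n + r] = acc[r] + (bias ? bias[n + r] : 0.f);
        }
    }
}
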
(Diffs for the remaining six changed files were not loaded on this page.)
