attention: use the packed matmul
chenqy4933 committed Aug 9, 2023
1 parent 20b308e commit ebbaac8
Showing 10 changed files with 409 additions and 343 deletions.
105 changes: 83 additions & 22 deletions src/core/op.cpp
@@ -296,7 +296,9 @@ size_t AttentionBase::get_workspace_in_byte() {
uint32_t K = inputs()[0]->shape()[1];
uint32_t N = weights()[0]->shape()[0];
auto kernel = get_kernel();

if (m_packed_weight) {
N *= PACK_SIZE;
}
uint32_t seqlen = input->shape()[0];

size_t total = 0;
@@ -326,6 +328,23 @@ size_t AttentionBase::get_workspace_in_byte() {
return total;
}

std::vector<size_t> AttentionBase::preprocess_weight(
Tensor* tensor, void* src, void* dst) {
INFER_ASSERT(tensor->dtype() == DType::Int4, "only support optimized int4 kernel");
auto weight_shape = tensor->shape();
size_t M = weight_shape[0];
size_t N = weight_shape[1];
INFER_ASSERT(N % QK40 == 0, "the embd size is not a multiple of QK40.");
INFER_ASSERT(M % PACK_SIZE == 0, "the M in matmul is not aligned to 8.");

auto kernel = get_kernel();
kernel->operator()<KernelID::MatmulInt4WeightReorder>(M, N, dst, src, PACK_SIZE);
size_t block_m = M / PACK_SIZE;

m_packed_weight = true;
return {block_m, N * PACK_SIZE};
}
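
The reorder above hands the raw {M, N} int4 weight to the MatmulInt4WeightReorder kernel and records the packed shape {M / PACK_SIZE, N * PACK_SIZE}, so each group of PACK_SIZE (8) output rows can be consumed together by the packed matmul. The sketch below only illustrates that block-interleave idea under assumed parameters; the function name, block sizes, and byte layout are not taken from inferllm, and the real layout is whatever MatmulInt4WeightReorder emits.

#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical illustration of an 8-row interleave: for every group of 8
// weight rows, their quantization blocks are stored block-by-block so one
// inner loop of the packed matmul visits 8 output channels per pass.
void pack_rows_by_eight(const uint8_t* src, uint8_t* dst, size_t M,
                        size_t blocks_per_row, size_t bytes_per_block) {
    const size_t pack = 8;  // mirrors PACK_SIZE in op.h
    for (size_t bm = 0; bm < M / pack; ++bm) {
        for (size_t b = 0; b < blocks_per_row; ++b) {
            for (size_t r = 0; r < pack; ++r) {
                const uint8_t* s = src +
                        ((bm * pack + r) * blocks_per_row + b) * bytes_per_block;
                uint8_t* d = dst +
                        ((bm * blocks_per_row + b) * pack + r) * bytes_per_block;
                std::memcpy(d, s, bytes_per_block);  // move one quant block
            }
        }
    }
}
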

void LlamaAttention::execute(WorkSpace* workspace, uint32_t nr_past) {
INFER_ASSERT(
nr_past == m_kstorage->current_index(),
@@ -398,12 +417,27 @@ void LlamaAttention::execute(WorkSpace* workspace, uint32_t nr_past) {
float* p_outq = static_cast<float*>(q_out);
switch (w_dtype) {
case DType::Int4:
kernel->operator()<KernelID::MatmulInt4Float>(
p_outq, p_wq, p_bq, pdata, seqlen, embd, embd, p_work, size);
kernel->operator()<KernelID::MatmulInt4Float>(
p_outk, p_wk, p_bk, pdata, seqlen, embd, embd, p_work, size);
kernel->operator()<KernelID::MatmulInt4Float>(
p_outv, p_wv, p_bv, pdata, seqlen, embd, embd, p_work, size);
if (!m_packed_weight) {
kernel->operator()<KernelID::MatmulInt4Float>(
p_outq, p_wq, p_bq, pdata, seqlen, embd, embd, p_work,
size);
kernel->operator()<KernelID::MatmulInt4Float>(
p_outk, p_wk, p_bk, pdata, seqlen, embd, embd, p_work,
size);
kernel->operator()<KernelID::MatmulInt4Float>(
p_outv, p_wv, p_bv, pdata, seqlen, embd, embd, p_work,
size);
} else {
kernel->operator()<KernelID::MatmulInt4FloatPacked>(
p_outq, p_wq, p_bq, pdata, seqlen, embd, embd, p_work,
size);
kernel->operator()<KernelID::MatmulInt4FloatPacked>(
p_outk, p_wk, p_bk, pdata, seqlen, embd, embd, p_work,
size);
kernel->operator()<KernelID::MatmulInt4FloatPacked>(
p_outv, p_wv, p_bv, pdata, seqlen, embd, embd, p_work,
size);
}
break;
case DType::Int8:
kernel->operator()<KernelID::MatmulInt8Float>(
@@ -541,12 +575,27 @@ void GlmAttention::execute(WorkSpace* workspace, uint32_t nr_past) {
float* p_outq = static_cast<float*>(q_out);
switch (w_dtype) {
case DType::Int4:
kernel->operator()<KernelID::MatmulInt4Float>(
p_outq, p_wq, p_bq, pdata, seqlen, embd, embd, p_work, size);
kernel->operator()<KernelID::MatmulInt4Float>(
p_outk, p_wk, p_bk, pdata, seqlen, embd, embd, p_work, size);
kernel->operator()<KernelID::MatmulInt4Float>(
p_outv, p_wv, p_bv, pdata, seqlen, embd, embd, p_work, size);
if (!m_packed_weight) {
kernel->operator()<KernelID::MatmulInt4Float>(
p_outq, p_wq, p_bq, pdata, seqlen, embd, embd, p_work,
size);
kernel->operator()<KernelID::MatmulInt4Float>(
p_outk, p_wk, p_bk, pdata, seqlen, embd, embd, p_work,
size);
kernel->operator()<KernelID::MatmulInt4Float>(
p_outv, p_wv, p_bv, pdata, seqlen, embd, embd, p_work,
size);
} else {
kernel->operator()<KernelID::MatmulInt4FloatPacked>(
p_outq, p_wq, p_bq, pdata, seqlen, embd, embd, p_work,
size);
kernel->operator()<KernelID::MatmulInt4FloatPacked>(
p_outk, p_wk, p_bk, pdata, seqlen, embd, embd, p_work,
size);
kernel->operator()<KernelID::MatmulInt4FloatPacked>(
p_outv, p_wv, p_bv, pdata, seqlen, embd, embd, p_work,
size);
}
break;
case DType::Int8:
kernel->operator()<KernelID::MatmulInt8Float>(
@@ -673,15 +722,27 @@ void Glm2MultiQueryAttention::execute(WorkSpace* workspace, uint32_t nr_past) {
float* p_outq = static_cast<float*>(q_out);
switch (w_dtype) {
case DType::Int4:
kernel->operator()<KernelID::MatmulInt4Float>(
p_outq, p_wq, p_bq, pdata, seqlen, embd, embd, p_work, size);
kernel->operator()<KernelID::MatmulInt4Float>(
p_outk, p_wk, p_bk, pdata, seqlen, kv_length, embd, p_work,
size);
kernel->operator()<KernelID::MatmulInt4Float>(
p_outv, p_wv, p_bv, pdata, seqlen, kv_length, embd, p_work,
size);
break;
if (!m_packed_weight) {
kernel->operator()<KernelID::MatmulInt4Float>(
p_outq, p_wq, p_bq, pdata, seqlen, embd, embd, p_work,
size);
kernel->operator()<KernelID::MatmulInt4Float>(
p_outk, p_wk, p_bk, pdata, seqlen, kv_length, embd, p_work,
size);
kernel->operator()<KernelID::MatmulInt4Float>(
p_outv, p_wv, p_bv, pdata, seqlen, kv_length, embd, p_work,
size);
} else {
kernel->operator()<KernelID::MatmulInt4FloatPacked>(
p_outq, p_wq, p_bq, pdata, seqlen, embd, embd, p_work,
size);
kernel->operator()<KernelID::MatmulInt4FloatPacked>(
p_outk, p_wk, p_bk, pdata, seqlen, kv_length, embd, p_work,
size);
kernel->operator()<KernelID::MatmulInt4FloatPacked>(
p_outv, p_wv, p_bv, pdata, seqlen, kv_length, embd, p_work,
size);
                    }
                    break;
case DType::Int8:
kernel->operator()<KernelID::MatmulInt8Float>(
p_outq, p_wq, p_bq, pdata, seqlen, embd, embd, p_work, size);
27 changes: 25 additions & 2 deletions src/core/op.h
@@ -7,6 +7,7 @@
namespace inferllm {

using OpIOs = std::vector<std::shared_ptr<Tensor>>;
constexpr static size_t PACK_SIZE = 8;

//! Base class of an Op, the call step is:
//! call deduce_output_shape to get the output tensor shape
@@ -161,10 +162,11 @@ class MatMul : public OpBase {
auto kernel = get_kernel();
//! only when the weight is int4
if (weight->name() == weights()[0]->name()) {
size_t M = weight->shape()[0];
bool optimized =
kernel->supported_optimization(KernelOptMethod::MatmulInt4Reorder);
bool int4 = weight->dtype() == DType::Int4;
if (optimized && int4) {
if (optimized && int4 && (M % PACK_SIZE == 0)) {
return true;
}
}
@@ -179,7 +181,6 @@
size_t get_workspace_in_byte() override;

bool m_bias = false;
constexpr static size_t PACK_SIZE = 8;
bool m_weight_packed = false;
};

@@ -384,13 +385,35 @@ class AttentionBase : public OpBase {
m_vstorage->reset_id();
}

virtual bool need_preprocess_weight(Tensor* weight) override {
auto kernel = get_kernel();
bool int4 = weight->dtype() == DType::Int4;
size_t M = weight->shape()[0];
bool right_weight = false;
bool optimized =
kernel->supported_optimization(KernelOptMethod::MatmulInt4Reorder);
//! only when the weight is int4
if (m_fused_weights) {
right_weight = weight->name() == weights()[0]->name();
} else {
right_weight = weight->name() == weights()[0]->name() ||
weight->name() == weights()[1]->name() ||
weight->name() == weights()[2]->name();
}
return optimized && int4 && right_weight && M % PACK_SIZE == 0;
}

virtual std::vector<size_t> preprocess_weight(
Tensor* tensor, void* src, void* dst) override;

protected:
uint32_t m_embd;
uint32_t m_head;
uint32_t m_ctx;
uint32_t m_layer_id;
bool m_fused_weights;
bool m_bias;
bool m_packed_weight = false;

std::unique_ptr<KvStorage> m_kstorage;
std::unique_ptr<KvStorage> m_vstorage;
27 changes: 27 additions & 0 deletions src/core/tensor.cpp
@@ -120,6 +120,33 @@ size_t Tensor::read_data_from_file() {
return length;
}

void Tensor::preprocess_data() {
size_t length = length_in_byte();
INFER_ASSERT(m_data, "m_data should not be null when preprocessing data.");
//! without unified memory, copy the data to host memory, preprocess it there, then copy it back to the device
auto opr = this->owner_op();
if (!m_device->unified_memory()) {
if (opr->need_preprocess_weight(this)) {
auto host_src = m_device->allocate_host(length);
auto host_dst = m_device->allocate_host(length);
m_device->device2host_copy(host_src, m_data, length);
auto shape = opr->preprocess_weight(this, host_src, host_dst);
m_device->host2device_copy(m_data, host_dst, length);
set_shape(shape);
m_device->free_host(host_src);
m_device->free_host(host_dst);
}
} else {
if (opr->need_preprocess_weight(this)) {
void* new_data = m_device->allocate(length);
auto shape = opr->preprocess_weight(this, m_data, new_data);
set_shape(shape);
m_device->free_device(m_data);
m_data = new_data;
}
}
}

void Tensor::set_shared_memory(void* data, size_t size) {
INFER_ASSERT(
data == nullptr || size >= length_in_byte(),
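
Taken together with the op-side hooks, the flow is: Tensor::preprocess_data asks the owner op via need_preprocess_weight, then either stages the weight through host buffers (no unified memory) or into a freshly allocated device buffer (unified memory), and finally adopts the packed shape returned by preprocess_weight. A minimal usage sketch follows; the loader function name is hypothetical, and only the ordering (read first, then preprocess) is implied by the code above.

// Hypothetical loading snippet; the real load path lives elsewhere in
// inferllm. It only shows the order implied by Tensor::preprocess_data.
void load_weight(inferllm::Tensor* weight) {
    weight->read_data_from_file();  // fill m_data with the on-disk layout
    weight->preprocess_data();      // reorder in place if the owner op asks
}
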
2 changes: 2 additions & 0 deletions src/core/tensor.h
@@ -153,6 +153,8 @@

size_t read_data_from_file();

void preprocess_data();

private:
bool m_shared = false;
int32_t m_usr_count = 0;
10 changes: 7 additions & 3 deletions src/kern/kernel.h
@@ -40,11 +40,15 @@ class Kernel {
}

bool supported_optimization(KernelOptMethod method) {
if (m_kernel_type == KernelType::Arm || m_kernel_type == KernelType::Naive ||
m_kernel_type == KernelType::X86) {
if (m_kernel_type == KernelType::Arm || m_kernel_type == KernelType::Naive) {
if (method == KernelOptMethod::MatmulInt4Reorder) {
#if defined(__ARM_FEATURE_DOTPROD)
return true;
#else
return false;
#endif
}
return false;
}
return false;
}
@@ -77,6 +81,6 @@
void set_handle(cudaHandle* handle) { m_handle = handle; }
cudaHandle* m_handle;
#endif
};
};

} // namespace inferllm
7 changes: 5 additions & 2 deletions src/kern/optimized/arm/kernel.cpp
@@ -234,8 +234,11 @@ TaskSet llm_matmul_compute_int4_float_packed(
int8_t* src = q_src + m * weight_q80_stride;
float* dst_ptr = dst + m * N + n * 8;
const float* bias_ptr = bias ? bias + n * 8 : nullptr;
//vec_vec_dot_q40_with_q80_packed(K, q_weight, src, dst_ptr, bias_ptr);
vec_vec_dot_q40_with_q80_packed_asm(K, q_weight, src, dst_ptr, bias_ptr);
#if defined(__ARM_FEATURE_DOTPROD)
vec_vec_dot_q40_with_q80_packed(K, q_weight, src, dst_ptr, bias_ptr);
// vec_vec_dot_q40_with_q80_packed_asm(K, q_weight, src, dst_ptr,
// bias_ptr);
#endif
}
}
};
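
In the packed kernel, each (m, n) tile now writes 8 contiguous floats at dst + m * N + n * 8, and the dot product is only compiled when the Arm dot-product extension is available, matching the gate in Kernel::supported_optimization, so a build without dotprod never selects the packed path. For readability, a scalar reference of what one tile computes is sketched below; the BlockQ40Ref/BlockQ80Ref layouts, the nibble pairing, and the interleave order are assumptions modeled on ggml-style q4_0/q8_0 blocks, not inferllm's actual definitions.

#include <cstdint>

// Assumed block layouts (QK40 == 32); the real BlockQ40/BlockQ80 live in
// inferllm's kernel headers and may differ (e.g. fp16 scales).
struct BlockQ40Ref { float d; uint8_t qs[16]; };  // 32 4-bit weights + scale
struct BlockQ80Ref { float d; int8_t qs[32]; };   // 32 8-bit activations + scale

// Scalar reference for one (m, n) tile: dot K quantized values of 8
// interleaved weight rows against one quantized activation row, producing
// 8 floats. A readability aid, not the NEON/dotprod implementation.
void dot_q40_q80_packed_ref(int K, const BlockQ40Ref* w, const BlockQ80Ref* a,
                            float* out, const float* bias) {
    const int pack = 8, QK = 32;
    for (int r = 0; r < pack; ++r) out[r] = bias ? bias[r] : 0.0f;
    for (int b = 0; b < K / QK; ++b) {
        const BlockQ80Ref& act = a[b];
        for (int r = 0; r < pack; ++r) {
            // blocks of the 8 rows are assumed interleaved block-by-block
            const BlockQ40Ref& wt = w[b * pack + r];
            int32_t sum = 0;
            for (int i = 0; i < QK / 2; ++i) {
                int lo = (wt.qs[i] & 0x0F) - 8;  // low nibble -> value i
                int hi = (wt.qs[i] >> 4) - 8;    // high nibble -> value i + 16
                sum += lo * act.qs[i] + hi * act.qs[i + QK / 2];
            }
            out[r] += wt.d * act.d * sum;
        }
    }
}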