Running the Llama 3.1-8B-Instruct model on a local CPU with 4 GB of RAM, without quantization, by loading and running one layer at a time from disk.
Posted by Lord_Momus@reddit | programming | View on Reddit | 14 comments
I am trying to run the Llama 3.1-8B-Instruct model https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct on a laptop with 4 GB of RAM. The idea is to load and run one layer at a time.
I have a class that initializes the key components of the LLaMA architecture:
LlamaTokenEmbed: Handles token embeddings.
LlamaLayer: Represents a transformer block.
LlamaFinalLayerNorm: Normalizes the output before final predictions.
LlamaFinalLayerHead: Generates final token probabilities.
Running Inference (run method)
The run method passes the tokens through the embedding layer.
It then iterates over the 32 transformer layers (LlamaLayer), loading each layer's weights from disk and running the layer on the input tensor x.
After all layers are processed, the final normalization and output head compute the final model output.
Here's the code
import time

import torch
from safetensors.torch import load_file

# precompute_theta_pos_frequencies, LlamaTokenEmbed, LlamaLayer,
# LlamaFinalLayerNorm and LlamaFinalLayerHead are defined elsewhere in my code.

class LlamaCpuDiskRun():
    def __init__(self, config):
        self.config = config
        self.freqs_complex = precompute_theta_pos_frequencies(
            self.config.dim // self.config.n_heads,
            self.config.max_position_embeddings * 2,
            device=self.config.device,
        )
        self.llamatoken = LlamaTokenEmbed(self.config)
        self.llamalayer = LlamaLayer(self.config, self.freqs_complex)
        self.llamafinalnorm = LlamaFinalLayerNorm(self.config)
        self.llamafinallmhead = LlamaFinalLayerHead(self.config)

        # The embeddings, final norm and LM head stay resident in RAM;
        # only the 32 transformer blocks are swapped in from disk one at a time.
        prev_time = time.time()
        self.llamatoken.load_state_dict(load_file(config.model_dir + "/separated_weights/embed_tokens.safetensors"), strict=True)
        print(time.time() - prev_time)
        self.llamafinalnorm.load_state_dict(load_file(config.model_dir + "/separated_weights/norm.safetensors"), strict=True)
        self.llamafinallmhead.load_state_dict(load_file(config.model_dir + "/separated_weights/lm_head.safetensors"), strict=True)

    def run(self, tokens: torch.Tensor, curr_pos: int):
        total_time = time.time()
        x = self.llamatoken(tokens)
        layer_time_avg = 0
        layer_load_t_avg = 0
        for i in range(0, 32):
            print(f"layer{i}")
            # Load this layer's weights from disk into the single reusable LlamaLayer.
            prev_time = time.time()
            self.llamalayer.load_state_dict(load_file(self.config.model_dir + f"/separated_weights/layers{i}.safetensors"), strict=True)
            t = time.time() - prev_time
            layer_load_t_avg += t
            print(t)
            # Run the layer on the current activations.
            prev_time = time.time()
            x = self.llamalayer(x, curr_pos)
            t = time.time() - prev_time
            layer_time_avg += t
            print(t)
        print("final layers")
        prev_time = time.time()
        x = self.llamafinallmhead(self.llamafinalnorm(x))
        print(time.time() - prev_time)
        print(x.shape)
        print("total time")
        print(time.time() - total_time)
        print(f"average layer compute and load time:{layer_time_avg/32},{layer_load_t_avg/32}")
Output:
total time
27.943154096603394
average layer compute and load time:0.03721388429403305,0.8325831741094589
Loading the weights takes most of the time: 0.832 × 32 ≈ 26.6 seconds, while compute takes only 0.037 × 32 ≈ 1.2 seconds.
Compute is roughly 22 times faster than loading the weights.
I am looking for ideas to minimize the weights loading time. Any idea on how I can improve this?
raiango@reddit
I haven't looked into the implementation, but do you know if the layers are fully connected? Could you avoid loading the nodes in the next layer that don't receive input?
Lord_Momus@reddit (OP)
There are 32 blocks, and inside each we have 3 feed-forward layers. What you have mentioned is what I have implemented. As I mentioned, loading the weights for a layer is what takes a lot of time.
One_Being7941@reddit
How do we get it where it doesn't censor things?
Lord_Momus@reddit (OP)
Do you mean the base llama model before it is fine tuned?
One_Being7941@reddit
I use ollama run dolphin-llama3
Wonderful-Wind-5736@reddit
Are the weights compressed? If not do so and load them on a second thread. I/O-bound operations don’t get penalized by the GIL.
Lord_Momus@reddit (OP)
The weights are not compressed. Thanks for the idea. I will try to compress them and load them via a second thread.
The thing is, I don't want to lose any information in my weights (that is why I didn't do quantization). Do you know of any compression technique I could look into, given that constraint?
Thanks for the reply!!
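[Editor's note: a rough sketch of what the compress-and-load-on-a-second-thread suggestion could look like, assuming the same per-layer file layout as OP's code. compress_layer and loader are hypothetical helpers; safetensors' load() parses the decompressed bytes. gzip is lossless, so no weight information is lost, but whether BF16 tensors shrink enough to offset the decompression cost is something OP would have to measure.]

import gzip
import shutil
import threading
from queue import Queue

from safetensors.torch import load  # parses safetensors bytes already in memory

# One-time offline step: gzip each per-layer shard (lossless, weights untouched).
def compress_layer(path: str) -> str:
    out_path = path + ".gz"
    with open(path, "rb") as src, gzip.open(out_path, "wb", compresslevel=9) as dst:
        shutil.copyfileobj(src, dst)
    return out_path

# Background loader: reads and decompresses the next layer while the main thread
# computes the current one. Disk reads and zlib decompression largely release the GIL.
def loader(paths, q: Queue):
    for path in paths:
        with gzip.open(path, "rb") as f:
            data = f.read()
        q.put(load(data))   # dict of tensors for one layer
    q.put(None)             # sentinel: no more layers

def run_with_background_loading(model_dir, llamalayer, x, curr_pos, n_layers=32):
    paths = [f"{model_dir}/separated_weights/layers{i}.safetensors.gz" for i in range(n_layers)]
    q = Queue(maxsize=2)    # keep at most two layers' weights in RAM at once
    t = threading.Thread(target=loader, args=(paths, q), daemon=True)
    t.start()
    while True:
        state_dict = q.get()
        if state_dict is None:
            break
        llamalayer.load_state_dict(state_dict, strict=True)
        x = llamalayer(x, curr_pos)
    t.join()
    return x

The compression would be done once offline; at run time only the .gz files are read.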
Wonderful-Wind-5736@reddit
GZIP at high compression ratio?
deeringc@reddit
I don't expect that there's a lot of opportunity for compression to reduce the size of LLM weights. The nature of this data is already "compressed" in some sense. Training a model is like an extremely elaborate lossy compression of the training data. If you take a zip file and compress it, the resulting file won't be smaller.
Wonderful-Wind-5736@reddit
Yeah, numerical data usually compresses badly with byte-based algorithms. But it's so easy to implement that OP might as well try it. Ideally he'd probably quantize and compress. Since these weights are used as linear operators, he could also try some spectral methods; accuracy is quite easy to tune with those.
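[Editor's note: one concrete reading of the spectral-methods suggestion is a truncated SVD: factor each weight matrix into two thin factors and keep only the top singular directions, so less data has to come off disk. A hedged sketch only; unlike plain compression this is lossy, which runs against OP's no-information-loss goal, and the rank/accuracy trade-off would have to be tuned per matrix.]

import torch

def low_rank_factors(weight: torch.Tensor, rank: int):
    """Truncated SVD: W (out x in) ~ A @ B with A (out x rank), B (rank x in).

    Storage drops from out*in to rank*(out+in) values, so the per-layer file
    (and hence the disk read) shrinks when rank is small enough. The dropped
    singular values are the approximation error.
    """
    # Do the SVD in float32 for numerical stability, even if weights are stored as BF16.
    U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
    A = U[:, :rank] * S[:rank]   # scale each kept column of U by its singular value
    B = Vh[:rank, :]
    return A.to(weight.dtype), B.to(weight.dtype)

def reconstruct(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    return A @ B  # approximate weight, rebuilt after loading the two small factors

# Example: a 1024x1024 matrix kept at rank 256 stores about half the original values.
W = torch.randn(1024, 1024, dtype=torch.bfloat16)
A, B = low_rank_factors(W, rank=256)
err = (reconstruct(A, B).float() - W.float()).norm() / W.float().norm()
print(f"relative reconstruction error: {err.item():.3f}")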
puddingfox@reddit
Lord_Momus@reddit (OP)
Faster loading of weights from disk to RAM.
puddingfox@reddit
You could be loading layer n+1 while running layer n. But based on your numbers that would not save you much time. You could also look at what the save_file and load_file functions do - maybe they can be optimized for your specific use case, e.g. maybe if all your data is float8 parameters the logic could be simplified?
Lord_Momus@reddit (OP)
Nice idea. I thought of the same and will try to implement it to see how much time I can save. But as you mentioned, I don't think it will save much time.
The parameters are BF16. I didn't fully understand the part about optimizing save_file and load_file. Could you please elaborate?
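[Editor's note: a minimal sketch of the overlap puddingfox describes, written as a possible variant of OP's run method and assuming OP's existing objects and file naming (llamatoken, llamalayer, layers{i}.safetensors); load_file is the safetensors function already used in the post. Since loading dominates compute by roughly 22x, the overlap can hide at most the compute time (about 1.2 s of the ~28 s total), which matches puddingfox's caveat. It also briefly holds two layers' weights in RAM at once, which matters on a 4 GB machine.]

from concurrent.futures import ThreadPoolExecutor

from safetensors.torch import load_file

def run_with_prefetch(self, tokens, curr_pos, n_layers=32):
    """Overlap the disk read for layer i+1 with the compute of layer i."""
    def path(i):
        return f"{self.config.model_dir}/separated_weights/layers{i}.safetensors"

    x = self.llamatoken(tokens)
    with ThreadPoolExecutor(max_workers=1) as pool:
        future = pool.submit(load_file, path(0))              # start reading layer 0
        for i in range(n_layers):
            state_dict = future.result()                      # weights for layer i
            if i + 1 < n_layers:
                future = pool.submit(load_file, path(i + 1))  # prefetch layer i+1
            self.llamalayer.load_state_dict(state_dict, strict=True)
            x = self.llamalayer(x, curr_pos)                  # compute while the next read runs
    return self.llamafinallmhead(self.llamafinalnorm(x))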