AI知识库

53AI知识库

学习大模型的前沿技术与行业应用场景


开源!我用Qwen2VL实现了一个多模态RAG
发布日期:2024-09-09 22:20:57 浏览次数: 1644


ColPali 是一种多模态检索器,直接对图像进行处理,无需OCR。

对数据建立索引后,使用 Qwen2-VL-7B 完成 RAG 的生成部分。

from pdf2image import convert_from_path

images = convert_from_path("/content/climate_youth_magazine.pdf")
images[5]

byaldi 是 answer.ai 开源的工具包,可轻松使用 ColPali

from byaldi import RAGMultiModalModel

RAG = RAGMultiModalModel.from_pretrained("vidore/colpali")

建立索引

RAG.index(
    input_path="/content/climate_youth_magazine.pdf",
    index_name="image_index"# index will be saved at index_root/index_name/
    store_collection_with_index=False,
    overwrite=True
)

然后就可以搜索了

text_query = "How much did the world temperature change so far?"
results = RAG.search(text_query, k=1)
results

[{'doc_id': 0, 'page_num': 6, 'score': 17.25, 'metadata': {}, 'base64': None}]

答案确实是在第6页,就是上面展示的那页pdf。现在我们可以构建一个 RAG 管道了。使用 Qwen2-VL-7B 模型。

from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch

model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct",
                                                        trust_remote_code=True, torch_dtype=torch.bfloat16).cuda().eval()
                                                        
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)

image_index = results[0]["page_num"] - 1
messages = [
    {
        "role""user",
        "content": [
            {
                "type""image",
                "image": images[image_index],
            },
            {"type""text""text": text_query},
        ],
    }
]

text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to("cuda")
generated_ids = model.generate(**inputs, max_new_tokens=50)
generated_ids_trimmed = [
    out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)

["The Earth's average global temperature has increased by around 1.1°C since the late 19th century, according to the information provided in the image."] 答案正确!



53AI,企业落地应用大模型首选服务商

产品:大模型应用平台+智能体定制开发+落地咨询服务

承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业

联系我们

售前咨询
186 6662 7370
预约演示
185 8882 0121

微信扫码

与创始人交个朋友

回到顶部

 
扫码咨询