зеркало из
				https://github.com/ssciwr/AMMICO.git
				synced 2025-10-30 21:46:04 +02:00 
			
		
		
		
	 8118aa4830
			
		
	
	
		8118aa4830
		
			
		
	
	
	
	
		
			
			* changed `3_non-deterministic summary` to `3_non-deterministic_summary` * changed cvlib to working fork * disabled test_display
		
			
				
	
	
		
			152 строки
		
	
	
		
			5.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			152 строки
		
	
	
		
			5.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import os
 | |
| import pytest
 | |
| from torch import device, cuda
 | |
| from lavis.models import load_model_and_preprocess
 | |
| import ammico.summary as sm
 | |
| 
 | |
| 
 | |
| IMAGES = ["d755771b-225e-432f-802e-fb8dc850fff7.png", "IMG_2746.png"]
 | |
| 
 | |
| SUMMARY_DEVICE = device("cuda" if cuda.is_available() else "cpu")
 | |
| 
 | |
| TEST_KWARGS = {
 | |
|     "run1": {
 | |
|         "name": "blip_caption",
 | |
|         "model_type": "base_coco",
 | |
|         "is_eval": True,
 | |
|         "device": SUMMARY_DEVICE,
 | |
|     },
 | |
|     "run2": {
 | |
|         "name": "blip_caption",
 | |
|         "model_type": "base_coco",
 | |
|         "is_eval": True,
 | |
|         "device": SUMMARY_DEVICE,
 | |
|     },
 | |
|     "run3": {
 | |
|         "name": "blip_caption",
 | |
|         "model_type": "large_coco",
 | |
|         "is_eval": True,
 | |
|         "device": SUMMARY_DEVICE,
 | |
|     },
 | |
| }
 | |
| 
 | |
| 
 | |
| @pytest.fixture
 | |
| def get_dict(get_path):
 | |
|     mydict = {}
 | |
|     for img in IMAGES:
 | |
|         id_ = os.path.splitext(os.path.basename(img))[0]
 | |
|         mydict[id_] = {"filename": get_path + img}
 | |
|     return mydict
 | |
| 
 | |
| 
 | |
| # @pytest.mark.long
 | |
| def test_analyse_image(get_dict):
 | |
|     reference_results = {
 | |
|         "run1": {
 | |
|             "d755771b-225e-432f-802e-fb8dc850fff7": "a river running through a city next to tall buildings",
 | |
|             "IMG_2746": "a crowd of people standing on top of a tennis court",
 | |
|         },
 | |
|         "run2": {
 | |
|             "d755771b-225e-432f-802e-fb8dc850fff7": "a river running through a city next to tall buildings",
 | |
|             "IMG_2746": "a crowd of people standing on top of a tennis court",
 | |
|         },
 | |
|         "run3": {
 | |
|             "d755771b-225e-432f-802e-fb8dc850fff7": "a river running through a town next to tall buildings",
 | |
|             "IMG_2746": "a crowd of people standing on top of a track",
 | |
|         },
 | |
|     }
 | |
|     # test three different models
 | |
|     for test_run in TEST_KWARGS.keys():
 | |
|         summary_model, summary_vis_processors, _ = load_model_and_preprocess(
 | |
|             **TEST_KWARGS[test_run]
 | |
|         )
 | |
|         # run two different images
 | |
|         for key in get_dict.keys():
 | |
|             get_dict[key] = sm.SummaryDetector(
 | |
|                 get_dict[key],
 | |
|                 analysis_type="summary",
 | |
|                 summary_model=summary_model,
 | |
|                 summary_vis_processors=summary_vis_processors,
 | |
|             ).analyse_image()
 | |
|         assert len(get_dict) == 2
 | |
|         for key in get_dict.keys():
 | |
|             assert len(get_dict[key]["3_non-deterministic_summary"]) == 3
 | |
|             assert (
 | |
|                 get_dict[key]["const_image_summary"] == reference_results[test_run][key]
 | |
|             )
 | |
|         cuda.empty_cache()
 | |
|         summary_model = None
 | |
|         summary_vis_processors = None
 | |
| 
 | |
| 
 | |
| @pytest.mark.win_skip
 | |
| def test_analyse_questions(get_dict):
 | |
|     list_of_questions = [
 | |
|         "How many persons on the picture?",
 | |
|         "What happends on the picture?",
 | |
|     ]
 | |
|     for key in get_dict:
 | |
|         get_dict[key] = sm.SummaryDetector(
 | |
|             get_dict[key],
 | |
|             analysis_type="questions",
 | |
|             list_of_questions=list_of_questions,
 | |
|         ).analyse_image()
 | |
|     assert len(get_dict) == 2
 | |
|     list_of_questions_ans = ["2", "100"]
 | |
|     list_of_questions_ans2 = ["flood", "festival"]
 | |
|     test_answers = []
 | |
|     test_answers2 = []
 | |
|     for key in get_dict.keys():
 | |
|         test_answers.append(get_dict[key][list_of_questions[0]])
 | |
|         test_answers2.append(get_dict[key][list_of_questions[1]])
 | |
|     assert sorted(test_answers) == sorted(list_of_questions_ans)
 | |
|     assert sorted(test_answers2) == sorted(list_of_questions_ans2)
 | |
| 
 | |
| 
 | |
| def test_init_summary():
 | |
|     sd = sm.SummaryDetector({}, analysis_type="summary")
 | |
|     assert sd.analysis_type == "summary"
 | |
|     with pytest.raises(ValueError):
 | |
|         sm.SummaryDetector({}, analysis_type="something")
 | |
|     list_of_questions = ["Question 1", "Question 2"]
 | |
|     sd = sm.SummaryDetector({}, list_of_questions=list_of_questions)
 | |
|     assert sd.list_of_questions == list_of_questions
 | |
|     with pytest.raises(ValueError):
 | |
|         sm.SummaryDetector({}, list_of_questions={})
 | |
|     with pytest.raises(ValueError):
 | |
|         sm.SummaryDetector({}, list_of_questions=[None])
 | |
|     with pytest.raises(ValueError):
 | |
|         sm.SummaryDetector({}, list_of_questions=[0.1])
 | |
| 
 | |
| 
 | |
| @pytest.mark.long
 | |
| def test_advanced_init_summary():
 | |
|     sd = sm.SummaryDetector({})
 | |
|     assert sd.summary_model
 | |
|     assert sd.summary_vis_processors
 | |
|     sd = sm.SummaryDetector({}, model_type="large")
 | |
|     assert sd.summary_model
 | |
|     assert sd.summary_vis_processors
 | |
|     with pytest.raises(ValueError):
 | |
|         sm.SummaryDetector({}, model_type="bla")
 | |
|     (
 | |
|         summary_vqa_model,
 | |
|         summary_vqa_vis_processors,
 | |
|         summary_vqa_txt_processors,
 | |
|     ) = load_model_and_preprocess(
 | |
|         name="blip_vqa",
 | |
|         model_type="vqav2",
 | |
|         is_eval=True,
 | |
|         device="cpu",
 | |
|     )
 | |
|     sd = sm.SummaryDetector(
 | |
|         {},
 | |
|         summary_vqa_model=summary_vqa_model,
 | |
|         summary_vqa_vis_processors=summary_vqa_vis_processors,
 | |
|         summary_vqa_txt_processors=summary_vqa_txt_processors,
 | |
|     )
 | |
|     assert sd.summary_vqa_model
 | |
|     assert sd.summary_vqa_vis_processors
 | |
|     assert sd.summary_vqa_txt_processors
 |