Mirror of
https://github.com/ssciwr/AMMICO.git
synced 2025-10-29 21:16:06 +02:00
maintain: remove cropposts, update pyproject.toml
This commit is contained in:
parent
f2c97e26ff
commit
a56b57f434
@@ -1,349 +0,0 @@
import os
import ntpath
import cv2
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Optional, Tuple, Union


MIN_MATCH_COUNT = 6
FLANN_INDEX_KDTREE = 1


# use this function to visualize the matches from SIFT
def draw_matches(
    matches: List,
    img1: np.ndarray,
    img2: np.ndarray,
    kp1: List[cv2.KeyPoint],
    kp2: List[cv2.KeyPoint],
) -> None:
    """Visualize the matches from SIFT.

    Args:
        matches (list[cv2.DMatch]): List of keypoint matches between the images.
        img1 (np.ndarray): The reference image.
        img2 (np.ndarray): The social media post.
        kp1 (list[cv2.KeyPoint]): List of keypoints from the first image.
        kp2 (list[cv2.KeyPoint]): List of keypoints from the second image.
    """
    if len(matches) > MIN_MATCH_COUNT:
        # Estimate homography between template and scene
        src_pts = np.float32([kp1[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2)
        dst_pts = np.float32([kp2[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2)
        M = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)[0]
        if not isinstance(M, np.ndarray):
            print("Could not match images for drawing.")
            return
        # Draw detected template in scene image
        h = img1.shape[0]
        w = img1.shape[1]
        pts = np.float32([[0, 0], [0, h - 1], [w - 1, h - 1], [w - 1, 0]]).reshape(
            -1, 1, 2
        )
        dst = cv2.perspectiveTransform(pts, M)
        img2 = cv2.polylines(img2, [np.int32(dst)], True, 255, 3, cv2.LINE_AA)
        # Paste both grayscale images side by side on one RGB canvas
        h1 = img1.shape[0]
        h2 = img2.shape[0]
        w1 = img1.shape[1]
        w2 = img2.shape[1]
        nwidth = w1 + w2
        nheight = max(h1, h2)
        hdif = int((h2 - h1) / 2)
        newimg = np.zeros((nheight, nwidth, 3), np.uint8)
        for i in range(3):
            newimg[hdif : hdif + h1, :w1, i] = img1
            newimg[:h2, w1 : w1 + w2, i] = img2
        # Draw SIFT keypoint matches
        for m in matches:
            pt1 = (int(kp1[m.queryIdx].pt[0]), int(kp1[m.queryIdx].pt[1] + hdif))
            pt2 = (int(kp2[m.trainIdx].pt[0] + w1), int(kp2[m.trainIdx].pt[1]))
            cv2.line(newimg, pt1, pt2, (255, 0, 0))
        plt.imshow(newimg)
        plt.show()
    else:
        print("Not enough matches are found - %d/%d" % (len(matches), MIN_MATCH_COUNT))


def matching_points(
    img1: np.ndarray, img2: np.ndarray
) -> Tuple[List[cv2.DMatch], List[cv2.KeyPoint], List[cv2.KeyPoint]]:
    """Compute keypoint matches between two images using the SIFT algorithm.

    Args:
        img1 (np.ndarray): The reference image.
        img2 (np.ndarray): The social media post.
    Returns:
        list[cv2.DMatch]: List of filtered keypoint matches.
        list[cv2.KeyPoint]: List of keypoints from the first image.
        list[cv2.KeyPoint]: List of keypoints from the second image.
    """
    img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
    img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
    sift = cv2.SIFT_create()
    kp1, des1 = sift.detectAndCompute(img1, None)
    kp2, des2 = sift.detectAndCompute(img2, None)

    # Convert descriptors to float32 for the FLANN matcher
    des1 = np.float32(des1)
    des2 = np.float32(des2)
    # Initialize and use FLANN
    index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
    search_params = dict(checks=50)
    flann = cv2.FlannBasedMatcher(index_params, search_params)
    matches = flann.knnMatch(des1, des2, k=2)
    filtered_matches = []
    for m, n in matches:
        # Apply Lowe's ratio test to filter out ambiguous matches
        if m.distance < 0.7 * n.distance:
            filtered_matches.append(m)
    return filtered_matches, kp1, kp2
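For orientation, here is a minimal usage sketch of the two functions above. It is an assumption-laden example: the image paths are the ones used by the test file further down, and `ammico.cropposts` is the import path the tests use.

import cv2
import ammico.cropposts as crpo  # import path taken from the test file below

# hypothetical example images: a reference pattern and a social media post
ref = cv2.imread("crop_test_ref_files/pic2.png")
post = cv2.imread("crop_test_files/pic1.png")

matches, kp1, kp2 = crpo.matching_points(ref, post)
print("{} matches survived the ratio test".format(len(matches)))

# draw_matches expects grayscale inputs, mirroring crop_posts_from_refs below
crpo.draw_matches(
    matches,
    cv2.cvtColor(ref, cv2.COLOR_BGR2GRAY),
    cv2.cvtColor(post, cv2.COLOR_BGR2GRAY),
    kp1,
    kp2,
)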


def kp_from_matches(matches, kp1: np.ndarray, kp2: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Extract the coordinates of the matched keypoints.

    Args:
        matches (list[cv2.DMatch]): List of keypoint matches.
        kp1 (np.ndarray): Keypoints from the first (reference) image.
        kp2 (np.ndarray): Keypoints from the second image (the post).
    Returns:
        np.ndarray: Coordinates of the matched query keypoints.
        np.ndarray: Coordinates of the matched train keypoints.
    """
    kp1 = np.float32([kp1[m.queryIdx].pt for m in matches])
    kp2 = np.float32([kp2[m.trainIdx].pt for m in matches])
    return kp1, kp2


def compute_crop_corner(
    matches: cv2.DMatch,
    kp1: np.ndarray,
    kp2: np.ndarray,
    region: int = 30,
    h_margin: int = 0,
    v_margin: int = 5,
    min_match: int = 6,
) -> Optional[Tuple[int, int]]:
    """Estimate the position on the image from where to crop.

    Args:
        matches (cv2.DMatch): The matched objects on the image.
        kp1 (np.ndarray): Keypoints of the matches for the reference image.
        kp2 (np.ndarray): Keypoints of the matches for the social media post.
        region (int, optional): Area to consider around the keypoints.
            Defaults to 30.
        h_margin (int, optional): Horizontal margin to subtract from the minimum
            horizontal position. Defaults to 0.
        v_margin (int, optional): Vertical margin to subtract from the minimum
            vertical position. Defaults to 5.
        min_match (int, optional): Minimum number of matches required. Defaults to 6.
    Returns:
        tuple, optional: Tuple of vertical and horizontal crop corner coordinates.
    """
    kp1, kp2 = kp_from_matches(matches, kp1, kp2)
    ys = kp2[:, 1]
    covers = []

    # Compute the number of keypoints within the region below each y-coordinate
    for y in ys:
        ys_c = ys - y
        series = pd.Series(ys_c)
        is_between = series.between(0, region)
        covers.append(is_between.sum())
    covers = np.array(covers)
    if covers.max() < min_match:
        return None
    kp_id = ys[covers.argmax()]
    v = int(kp_id) - v_margin if int(kp_id) > v_margin else int(kp_id)

    hs = []

    # Find the minimum x-coordinate within the region below the selected y-coordinate
    for kp in kp2:
        if 0 <= kp[1] - v <= region:
            hs.append(kp[0])
    # Subtract the horizontal margin unless the leftmost match already sits within it
    h = int(np.min(hs)) - h_margin if int(np.min(hs)) > h_margin else 0
    return v, h
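To make the crop-corner heuristic concrete, here is a small self-contained sketch of the same density count with made-up y-coordinates; `ys` and `region` are illustrative values, not data from the repository.

import numpy as np
import pandas as pd

# y-coordinates of matched keypoints: a loose trio near the top and a
# dense cluster further down, where the reference pattern would sit
ys = np.array([10.0, 12.0, 15.0, 200.0, 205.0, 207.0, 210.0, 212.0, 215.0])
region = 30

# for each y, count the keypoints that fall within `region` pixels below it
covers = np.array([pd.Series(ys - y).between(0, region).sum() for y in ys])
print(ys[covers.argmax()])  # 200.0 -> the densest cluster starts here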


def crop_posts_image(
    ref_view: np.ndarray,
    view: np.ndarray,
) -> Union[None, Tuple[np.ndarray, int, int, int]]:
    """Crop the social media post to exclude additional comments. Sometimes this also
    crops the image part of the post - this is put back in later.

    Args:
        ref_view (np.ndarray): The reference image that signifies below which region
            the post should be cropped.
        view (np.ndarray): The image to crop.
    Returns:
        tuple, optional: The cropped social media post, the number of filtered
            matches, and the vertical and horizontal crop positions.
    """
    filtered_matches, kp1, kp2 = matching_points(ref_view, view)
    if len(filtered_matches) < MIN_MATCH_COUNT:
        # not enough matches found
        return None
    corner = compute_crop_corner(filtered_matches, kp1, kp2)
    if corner is None:
        # no cropping corner found
        return None
    v, h = corner
    # if the match is on the right-hand side of the image, it is likely that
    # there is an image to the left that should not be cropped; in this case,
    # we adjust the margin for the text to be cropped to `correct_margin`
    # if the match is more to the left on the image, we assume it starts at
    # horizontal position 0 and do not want to cut off any characters from the text
    correct_margin = 30
    if h >= view.shape[1] / 2:
        h = h - correct_margin
    else:
        h = 0
    crop_view = view[0:v, h:, :]
    return crop_view, len(filtered_matches), v, h


def crop_posts_from_refs(
    ref_views: List,
    view: np.ndarray,
    plt_match: bool = False,
    plt_crop: bool = False,
    plt_image: bool = False,
) -> np.ndarray:
    """Crop the social media post comments from the image.

    Args:
        ref_views (list): List of all the reference images (as numpy arrays) that
            signify below which regions should be cropped.
        view (np.ndarray): The image to crop.
        plt_match (bool, optional): Display the matched areas. Defaults to False.
        plt_crop (bool, optional): Display the cropped text part. Defaults to False.
        plt_image (bool, optional): Display the image part. Defaults to False.
    Returns:
        np.ndarray: The cropped social media post.
    """
    crop_view = None
    # initialize the number of found matches per reference to zero
    # so later we can select the reference with the most matches
    max_matchs = 0
    rte = None
    found_match = False
    for ref_view in ref_views:
        rte = crop_posts_image(ref_view, view)
        if rte is not None:
            crop_img, match_num, v, h = rte
            if match_num > max_matchs:
                # select the reference with the most matches to crop accordingly
                crop_view = crop_img
                final_ref = ref_view
                final_v = v
                final_h = h
                max_matchs = match_num
                found_match = True

    if found_match and plt_match:
        # plot the match
        filtered_matches, kp1, kp2 = matching_points(final_ref, view)
        img1 = cv2.cvtColor(final_ref, cv2.COLOR_BGR2GRAY)
        img2 = cv2.cvtColor(view, cv2.COLOR_BGR2GRAY)
        draw_matches(filtered_matches, img1, img2, kp1, kp2)

    if found_match and plt_crop:
        # plot the cropped image with the crop lines marked in red
        view2 = view.copy()
        view2[final_v, :, 0:3] = [255, 0, 0]
        view2[:, final_h, 0:3] = [255, 0, 0]
        plt.imshow(cv2.cvtColor(view2, cv2.COLOR_BGR2RGB))
        plt.show()
        plt.imshow(cv2.cvtColor(crop_view, cv2.COLOR_BGR2RGB))
        plt.show()

    if found_match and final_h >= view.shape[1] / 2:
        # here it would crop the actual image from the social media post;
        # to avoid this, we check the position from where it would crop:
        # if > than half of the width of the image, also keep all that is
        # on the left-hand side of the crop
        crop_post = crop_image_from_post(view, final_h)
        if plt_image:
            # plot the image part of the social media post
            plt.imshow(cv2.cvtColor(crop_post, cv2.COLOR_BGR2RGB))
            plt.show()
        # now concatenate the image and the text part
        crop_view = paste_image_and_comment(crop_post, crop_view)
    return crop_view


def crop_image_from_post(view: np.ndarray, final_h: int) -> np.ndarray:
    """Crop the image part from the social media post.

    Args:
        view (np.ndarray): The image to be cropped.
        final_h (int): The horizontal position up to which should be cropped.
    Returns:
        np.ndarray: The cropped image part.
    """
    crop_post = view[:, 0:final_h, :]
    return crop_post


def paste_image_and_comment(crop_post: np.ndarray, crop_view: np.ndarray) -> np.ndarray:
    """Paste the image part and the text part together without the unnecessary comments.

    Args:
        crop_post (np.ndarray): The cropped image part of the social media post.
        crop_view (np.ndarray): The cropped text part of the social media post.
    Returns:
        np.ndarray: The image and text part of the social media post in one image.
    """
    h1, w1 = crop_post.shape[:2]
    h2, w2 = crop_view.shape[:2]
    image_all = np.zeros((max(h1, h2), w1 + w2, 3), np.uint8)
    image_all[:h1, :w1, :3] = crop_post
    image_all[:h2, w1 : w1 + w2, :3] = crop_view
    return image_all


def crop_media_posts(
    files, ref_files, save_crop_dir, plt_match=False, plt_crop=False, plt_image=False
) -> None:
    """Crop social media posts so that comments beyond the first comment/post are cut off.

    Args:
        files (dict): The files to be cropped; each value carries a "filename" entry.
        ref_files (dict): The reference images that signify below which regions
            should be cropped; each value carries a "filename" entry.
        save_crop_dir (str): Directory to write the cropped social media posts to.
        plt_match (bool, optional): Display the matched areas on the social media post.
            Defaults to False.
        plt_crop (bool, optional): Display the cropped text part of the social media
            post. Defaults to False.
        plt_image (bool, optional): Display the image part of the social media post.
            Defaults to False.
    """
    # get the reference images with regions that signify areas to crop
    ref_views = []
    for ref_file in ref_files.values():
        ref_file_path = ref_file["filename"]
        ref_view = cv2.imread(ref_file_path)
        ref_views.append(ref_view)
    # parse through the social media posts to be cropped
    for crop_file in files.values():
        crop_file_path = crop_file["filename"]
        view = cv2.imread(crop_file_path)
        print("Doing file {}".format(crop_file_path))
        crop_view = crop_posts_from_refs(
            ref_views,
            view,
            plt_match=plt_match,
            plt_crop=plt_crop,
            plt_image=plt_image,
        )
        if crop_view is not None:
            # save the image to the provided folder
            filename = ntpath.basename(crop_file_path)
            save_path = os.path.join(save_crop_dir, filename)
            save_path = save_path.replace("\\", "/")
            cv2.imwrite(save_path, crop_view)
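As a hedged end-to-end sketch, the entry point above can be driven the way the test file below drives it; the directory names mirror the test data, and `utils.find_files` is assumed to return a dict whose values carry a "filename" key, as the loops above expect.

import os
import ammico.cropposts as crpo
import ammico.utils as utils

files = utils.find_files(path="crop_test_files/")
ref_files = utils.find_files(path="crop_test_ref_files/")
save_dir = "cropped/"  # hypothetical output directory
os.makedirs(save_dir, exist_ok=True)  # cv2.imwrite needs an existing directory
crpo.crop_media_posts(files, ref_files, save_dir, plt_crop=True)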
@@ -1,88 +0,0 @@
import ammico.cropposts as crpo
import cv2
import pytest
import numpy as np
import ammico.utils as utils


TEST_IMAGE_1 = "crop_test_files/pic1.png"
TEST_IMAGE_2 = "crop_test_ref_files/pic2.png"


@pytest.fixture
def open_images(get_path):
    ref_view = cv2.imread(get_path + TEST_IMAGE_2)
    view = cv2.imread(get_path + TEST_IMAGE_1)
    return ref_view, view


def test_matching_points(open_images):
    filtered_matches, _, _ = crpo.matching_points(open_images[0], open_images[1])
    assert len(filtered_matches) > 0


def test_kp_from_matches(open_images):
    filtered_matches, kp1, kp2 = crpo.matching_points(open_images[0], open_images[1])
    kp1, kp2 = crpo.kp_from_matches(filtered_matches, kp1, kp2)
    assert kp1.shape[0] == len(filtered_matches)
    assert kp2.shape[0] == len(filtered_matches)
    assert kp1.shape[1] == 2
    assert kp2.shape[1] == 2


def test_compute_crop_corner(open_images):
    filtered_matches, kp1, kp2 = crpo.matching_points(open_images[0], open_images[1])
    corner = crpo.compute_crop_corner(filtered_matches, kp1, kp2)
    assert corner is not None
    v, h = corner
    assert 0 <= v < open_images[1].shape[0]
    # the horizontal corner must lie within the image width
    assert 0 <= h < open_images[1].shape[1]


def test_crop_posts_image(open_images):
    rte = crpo.crop_posts_image(open_images[0], open_images[1])
    assert rte is not None
    crop_view, match_num, _, _ = rte
    assert match_num > 0
    assert (
        crop_view.shape[0] * crop_view.shape[1]
        <= open_images[1].shape[0] * open_images[1].shape[1]
    )


def test_crop_posts_from_refs(open_images):
    crop_view = crpo.crop_posts_from_refs([open_images[0]], open_images[1])
    assert (
        crop_view.shape[0] * crop_view.shape[1]
        <= open_images[1].shape[0] * open_images[1].shape[1]
    )


def test_crop_image_from_post(open_images):
    crop_post = crpo.crop_image_from_post(open_images[0], 4)
    ref_array = np.array(
        [[220, 202, 155], [221, 204, 155], [221, 204, 155], [221, 204, 155]],
        dtype=np.uint8,
    )
    assert np.array_equal(crop_post[0], ref_array)


def test_paste_image_and_comment(open_images):
    full_post = crpo.paste_image_and_comment(open_images[0], open_images[1])
    ref_array1 = np.array([220, 202, 155], dtype=np.uint8)
    ref_array2 = np.array([74, 76, 64], dtype=np.uint8)
    assert np.array_equal(full_post[0, 0], ref_array1)
    assert np.array_equal(full_post[-1, -1], ref_array2)


def test_crop_media_posts(get_path, tmp_path):
    files = utils.find_files(path=get_path + "crop_test_files/")
    ref_files = utils.find_files(path=get_path + "crop_test_ref_files/")
    crpo.crop_media_posts(files, ref_files, tmp_path)
    assert len(list(tmp_path.iterdir())) == 1
    # now check that the image in tmp_path is the cropped one
    filename = tmp_path / "pic1.png"
    cropped_image = cv2.imread(str(filename))
    ref = np.array([222, 205, 156], dtype=np.uint8)
    assert np.array_equal(cropped_image[0, 0], ref)
@@ -34,27 +34,21 @@ dependencies = [
    "matplotlib",
    "nbval",
    "numpy<=1.23.4",
    "opencv",
    "pandas",
    "peft<=0.13.0",
    "Pillow",
    "pooch",
    "protobuf",
    "pytest",
    "pytest-cov",
    "Requests",
    "retina_face",
    "ammico-lavis>=1.0.2.3",
    "huggingface-hub<=0.25.2",
    "setuptools",
    "spacy<=3.7.5",
    "tensorflow>=2.13.0",
    "torch<2.6.0",
    "tensorflow",
    "google-cloud-vision",
    "dash_bootstrap_components",
    "colorgram.py",
    "webcolors>1.13",
    "colour-science",
    "scikit-learn>1.3.0",
    "tqdm"
]